python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/utils/data.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,7 +13,6 @@ import urllib
13
13
  import urllib.error
14
14
  import urllib.request
15
15
  from pathlib import Path
16
- from typing import Optional, Union
17
16
 
18
17
  from tqdm.auto import tqdm
19
18
 
@@ -25,7 +24,7 @@ HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
25
24
  USER_AGENT = "mindee/doctr"
26
25
 
27
26
 
28
- def _urlretrieve(url: str, filename: Union[Path, str], chunk_size: int = 1024) -> None:
27
+ def _urlretrieve(url: str, filename: Path | str, chunk_size: int = 1024) -> None:
29
28
  with open(filename, "wb") as fh:
30
29
  with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
31
30
  with tqdm(total=response.length) as pbar:
@@ -36,7 +35,7 @@ def _urlretrieve(url: str, filename: Union[Path, str], chunk_size: int = 1024) -
36
35
  fh.write(chunk)
37
36
 
38
37
 
39
- def _check_integrity(file_path: Union[str, Path], hash_prefix: str) -> bool:
38
+ def _check_integrity(file_path: str | Path, hash_prefix: str) -> bool:
40
39
  with open(file_path, "rb") as f:
41
40
  sha_hash = hashlib.sha256(f.read()).hexdigest()
42
41
 
@@ -45,10 +44,10 @@ def _check_integrity(file_path: Union[str, Path], hash_prefix: str) -> bool:
45
44
 
46
45
  def download_from_url(
47
46
  url: str,
48
- file_name: Optional[str] = None,
49
- hash_prefix: Optional[str] = None,
50
- cache_dir: Optional[str] = None,
51
- cache_subdir: Optional[str] = None,
47
+ file_name: str | None = None,
48
+ hash_prefix: str | None = None,
49
+ cache_dir: str | None = None,
50
+ cache_subdir: str | None = None,
52
51
  ) -> Path:
53
52
  """Download a file using its URL
54
53
 
@@ -56,7 +55,6 @@ def download_from_url(
56
55
  >>> download_from_url("https://yoursource.com/yourcheckpoint-yourhash.zip")
57
56
 
58
57
  Args:
59
- ----
60
58
  url: the URL of the file to download
61
59
  file_name: optional name of the file once downloaded
62
60
  hash_prefix: optional expected SHA256 hash of the file
@@ -64,11 +62,9 @@ def download_from_url(
64
62
  cache_subdir: subfolder to use in the cache
65
63
 
66
64
  Returns:
67
- -------
68
65
  the location of the downloaded file
69
66
 
70
67
  Note:
71
- ----
72
68
  You can change cache directory location by using `DOCTR_CACHE_DIR` environment variable.
73
69
  """
74
70
  if not isinstance(file_name, str):
@@ -96,7 +92,7 @@ def download_from_url(
96
92
  # Create folder hierarchy
97
93
  folder_path.mkdir(parents=True, exist_ok=True)
98
94
  except OSError:
99
- error_message = f"Failed creating cache direcotry at {folder_path}"
95
+ error_message = f"Failed creating cache directory at {folder_path}"
100
96
  if os.environ.get("DOCTR_CACHE_DIR", ""):
101
97
  error_message += " using path from 'DOCTR_CACHE_DIR' environment variable."
102
98
  else:
@@ -112,7 +108,7 @@ def download_from_url(
112
108
  except (urllib.error.URLError, IOError) as e:
113
109
  if url[:5] == "https":
114
110
  url = url.replace("https:", "http:")
115
- print("Failed download. Trying https -> http instead." f" Downloading {url} to {file_path}")
111
+ print(f"Failed download. Trying https -> http instead. Downloading {url} to {file_path}")
116
112
  _urlretrieve(url, file_path)
117
113
  else:
118
114
  raise e
doctr/utils/fonts.py CHANGED
@@ -1,29 +1,24 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  import logging
7
7
  import platform
8
- from typing import Optional, Union
9
8
 
10
9
  from PIL import ImageFont
11
10
 
12
11
  __all__ = ["get_font"]
13
12
 
14
13
 
15
- def get_font(
16
- font_family: Optional[str] = None, font_size: int = 13
17
- ) -> Union[ImageFont.FreeTypeFont, ImageFont.ImageFont]:
14
+ def get_font(font_family: str | None = None, font_size: int = 13) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
18
15
  """Resolves a compatible ImageFont for the system
19
16
 
20
17
  Args:
21
- ----
22
18
  font_family: the font family to use
23
19
  font_size: the size of the font upon rendering
24
20
 
25
21
  Returns:
26
- -------
27
22
  the Pillow font
28
23
  """
29
24
  # Font selection
doctr/utils/geometry.py CHANGED
@@ -1,11 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from copy import deepcopy
7
7
  from math import ceil
8
- from typing import List, Optional, Tuple, Union
9
8
 
10
9
  import cv2
11
10
  import numpy as np
@@ -34,11 +33,9 @@ def bbox_to_polygon(bbox: BoundingBox) -> Polygon4P:
34
33
  """Convert a bounding box to a polygon
35
34
 
36
35
  Args:
37
- ----
38
36
  bbox: a bounding box
39
37
 
40
38
  Returns:
41
- -------
42
39
  a polygon
43
40
  """
44
41
  return bbox[0], (bbox[1][0], bbox[0][1]), (bbox[0][0], bbox[1][1]), bbox[1]
@@ -48,31 +45,27 @@ def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox:
48
45
  """Convert a polygon to a bounding box
49
46
 
50
47
  Args:
51
- ----
52
48
  polygon: a polygon
53
49
 
54
50
  Returns:
55
- -------
56
51
  a bounding box
57
52
  """
58
53
  x, y = zip(*polygon)
59
54
  return (min(x), min(y)), (max(x), max(y))
60
55
 
61
56
 
62
- def detach_scores(boxes: List[np.ndarray]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
57
+ def detach_scores(boxes: list[np.ndarray]) -> tuple[list[np.ndarray], list[np.ndarray]]:
63
58
  """Detach the objectness scores from box predictions
64
59
 
65
60
  Args:
66
- ----
67
61
  boxes: list of arrays with boxes of shape (N, 5) or (N, 5, 2)
68
62
 
69
63
  Returns:
70
- -------
71
64
  a tuple of two lists: the first one contains the boxes without the objectness scores,
72
65
  the second one contains the objectness scores
73
66
  """
74
67
 
75
- def _detach(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
68
+ def _detach(boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
76
69
  if boxes.ndim == 2:
77
70
  return boxes[:, :-1], boxes[:, -1]
78
71
  return boxes[:, :-1], boxes[:, -1, -1]
@@ -81,11 +74,10 @@ def detach_scores(boxes: List[np.ndarray]) -> Tuple[List[np.ndarray], List[np.nd
81
74
  return list(loc_preds), list(obj_scores)
82
75
 
83
76
 
84
- def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Union[BoundingBox, np.ndarray]:
77
+ def resolve_enclosing_bbox(bboxes: list[BoundingBox] | np.ndarray) -> BoundingBox | np.ndarray:
85
78
  """Compute enclosing bbox either from:
86
79
 
87
80
  Args:
88
- ----
89
81
  bboxes: boxes in one of the following formats:
90
82
 
91
83
  - an array of boxes: (*, 4), where boxes have this shape:
@@ -94,7 +86,6 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio
94
86
  - a list of BoundingBox
95
87
 
96
88
  Returns:
97
- -------
98
89
  a (1, 4) array (enclosing boxarray), or a BoundingBox
99
90
  """
100
91
  if isinstance(bboxes, np.ndarray):
@@ -105,11 +96,10 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio
105
96
  return (min(x), min(y)), (max(x), max(y))
106
97
 
107
98
 
108
- def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024) -> np.ndarray:
99
+ def resolve_enclosing_rbbox(rbboxes: list[np.ndarray], intermed_size: int = 1024) -> np.ndarray:
109
100
  """Compute enclosing rotated bbox either from:
110
101
 
111
102
  Args:
112
- ----
113
103
  rbboxes: boxes in one of the following formats:
114
104
 
115
105
  - an array of boxes: (*, 4, 2), where boxes have this shape:
@@ -119,26 +109,23 @@ def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024
119
109
  intermed_size: size of the intermediate image
120
110
 
121
111
  Returns:
122
- -------
123
112
  a (4, 2) array (enclosing rotated box)
124
113
  """
125
114
  cloud: np.ndarray = np.concatenate(rbboxes, axis=0)
126
115
  # Convert to absolute for minAreaRect
127
116
  cloud *= intermed_size
128
117
  rect = cv2.minAreaRect(cloud.astype(np.int32))
129
- return cv2.boxPoints(rect) / intermed_size # type: ignore[return-value]
118
+ return cv2.boxPoints(rect) / intermed_size
130
119
 
131
120
 
132
121
  def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray:
133
122
  """Rotate points counter-clockwise.
134
123
 
135
124
  Args:
136
- ----
137
125
  points: array of size (N, 2)
138
126
  angle: angle between -90 and +90 degrees
139
127
 
140
128
  Returns:
141
- -------
142
129
  Rotated points
143
130
  """
144
131
  angle_rad = angle * np.pi / 180.0 # compute radian angle for np functions
@@ -148,16 +135,14 @@ def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray:
148
135
  return np.matmul(points, rotation_mat.T)
149
136
 
150
137
 
151
- def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[int, int]:
138
+ def compute_expanded_shape(img_shape: tuple[int, int], angle: float) -> tuple[int, int]:
152
139
  """Compute the shape of an expanded rotated image
153
140
 
154
141
  Args:
155
- ----
156
142
  img_shape: the height and width of the image
157
143
  angle: angle between -90 and +90 degrees
158
144
 
159
145
  Returns:
160
- -------
161
146
  the height and width of the rotated image
162
147
  """
163
148
  points: np.ndarray = np.array([
@@ -174,21 +159,19 @@ def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[in
174
159
  def rotate_abs_geoms(
175
160
  geoms: np.ndarray,
176
161
  angle: float,
177
- img_shape: Tuple[int, int],
162
+ img_shape: tuple[int, int],
178
163
  expand: bool = True,
179
164
  ) -> np.ndarray:
180
165
  """Rotate a batch of bounding boxes or polygons by an angle around the
181
166
  image center.
182
167
 
183
168
  Args:
184
- ----
185
169
  geoms: (N, 4) or (N, 4, 2) array of ABSOLUTE coordinate boxes
186
170
  angle: anti-clockwise rotation angle in degrees
187
171
  img_shape: the height and width of the image
188
172
  expand: whether the image should be padded to avoid information loss
189
173
 
190
174
  Returns:
191
- -------
192
175
  A batch of rotated polygons (N, 4, 2)
193
176
  """
194
177
  # Switch to polygons
@@ -214,19 +197,17 @@ def rotate_abs_geoms(
214
197
  return rotated_polys
215
198
 
216
199
 
217
- def remap_boxes(loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape: Tuple[int, int]) -> np.ndarray:
200
+ def remap_boxes(loc_preds: np.ndarray, orig_shape: tuple[int, int], dest_shape: tuple[int, int]) -> np.ndarray:
218
201
  """Remaps a batch of rotated locpred (N, 4, 2) expressed for an origin_shape to a destination_shape.
219
202
  This does not impact the absolute shape of the boxes, but allow to calculate the new relative RotatedBbox
220
203
  coordinates after a resizing of the image.
221
204
 
222
205
  Args:
223
- ----
224
206
  loc_preds: (N, 4, 2) array of RELATIVE loc_preds
225
207
  orig_shape: shape of the origin image
226
208
  dest_shape: shape of the destination image
227
209
 
228
210
  Returns:
229
- -------
230
211
  A batch of rotated loc_preds (N, 4, 2) expressed in the destination referencial
231
212
  """
232
213
  if len(dest_shape) != 2:
@@ -245,9 +226,9 @@ def remap_boxes(loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape:
245
226
  def rotate_boxes(
246
227
  loc_preds: np.ndarray,
247
228
  angle: float,
248
- orig_shape: Tuple[int, int],
229
+ orig_shape: tuple[int, int],
249
230
  min_angle: float = 1.0,
250
- target_shape: Optional[Tuple[int, int]] = None,
231
+ target_shape: tuple[int, int] | None = None,
251
232
  ) -> np.ndarray:
252
233
  """Rotate a batch of straight bounding boxes (xmin, ymin, xmax, ymax, c) or rotated bounding boxes
253
234
  (4, 2) of an angle, if angle > min_angle, around the center of the page.
@@ -255,7 +236,6 @@ def rotate_boxes(
255
236
  is done to remove the padding that is created by rotate_page(expand=True)
256
237
 
257
238
  Args:
258
- ----
259
239
  loc_preds: (N, 4) or (N, 4, 2) array of RELATIVE boxes
260
240
  angle: angle between -90 and +90 degrees
261
241
  orig_shape: shape of the origin image
@@ -263,7 +243,6 @@ def rotate_boxes(
263
243
  target_shape: shape of the destination image
264
244
 
265
245
  Returns:
266
- -------
267
246
  A batch of rotated boxes (N, 4, 2): or a batch of straight bounding boxes
268
247
  """
269
248
  # Change format of the boxes to rotated boxes
@@ -310,20 +289,18 @@ def rotate_image(
310
289
  """Rotate an image counterclockwise by a given angle.
311
290
 
312
291
  Args:
313
- ----
314
292
  image: numpy tensor to rotate
315
293
  angle: rotation angle in degrees, between -90 and +90
316
294
  expand: whether the image should be padded before the rotation
317
295
  preserve_origin_shape: if expand is set to True, resizes the final output to the original image size
318
296
 
319
297
  Returns:
320
- -------
321
298
  Rotated array, padded by 0 by default.
322
299
  """
323
300
  # Compute the expanded padding
324
301
  exp_img: np.ndarray
325
302
  if expand:
326
- exp_shape = compute_expanded_shape(image.shape[:2], angle) # type: ignore[arg-type]
303
+ exp_shape = compute_expanded_shape(image.shape[:2], angle)
327
304
  h_pad, w_pad = (
328
305
  int(max(0, ceil(exp_shape[0] - image.shape[0]))),
329
306
  int(max(0, ceil(exp_shape[1] - image.shape[1]))),
@@ -344,7 +321,7 @@ def rotate_image(
344
321
  # Pad height
345
322
  else:
346
323
  h_pad, w_pad = int(rot_img.shape[1] * image.shape[0] / image.shape[1] - rot_img.shape[0]), 0
347
- rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0))) # type: ignore[assignment]
324
+ rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0)))
348
325
  if preserve_origin_shape:
349
326
  # rescale
350
327
  rot_img = cv2.resize(rot_img, image.shape[:-1][::-1], interpolation=cv2.INTER_LINEAR)
@@ -356,11 +333,9 @@ def remove_image_padding(image: np.ndarray) -> np.ndarray:
356
333
  """Remove black border padding from an image
357
334
 
358
335
  Args:
359
- ----
360
336
  image: numpy tensor to remove padding from
361
337
 
362
338
  Returns:
363
- -------
364
339
  Image with padding removed
365
340
  """
366
341
  # Find the bounding box of the non-black region
@@ -390,16 +365,14 @@ def estimate_page_angle(polys: np.ndarray) -> float:
390
365
  return 0.0
391
366
 
392
367
 
393
- def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray:
368
+ def convert_to_relative_coords(geoms: np.ndarray, img_shape: tuple[int, int]) -> np.ndarray:
394
369
  """Convert a geometry to relative coordinates
395
370
 
396
371
  Args:
397
- ----
398
372
  geoms: a set of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4)
399
373
  img_shape: the height and width of the image
400
374
 
401
375
  Returns:
402
- -------
403
376
  the updated geometry
404
377
  """
405
378
  # Polygon
@@ -417,18 +390,16 @@ def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) ->
417
390
  raise ValueError(f"invalid format for arg `geoms`: {geoms.shape}")
418
391
 
419
392
 
420
- def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> List[np.ndarray]:
393
+ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> list[np.ndarray]:
421
394
  """Create cropped images from a list of bounding boxes
422
395
 
423
396
  Args:
424
- ----
425
397
  img: input image
426
398
  boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative
427
399
  coordinates (xmin, ymin, xmax, ymax)
428
400
  channels_last: whether the channel dimension is the last one instead of the first one
429
401
 
430
402
  Returns:
431
- -------
432
403
  list of cropped images
433
404
  """
434
405
  if boxes.shape[0] == 0:
@@ -453,11 +424,10 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True
453
424
 
454
425
  def extract_rcrops(
455
426
  img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True, assume_horizontal: bool = False
456
- ) -> List[np.ndarray]:
427
+ ) -> list[np.ndarray]:
457
428
  """Create cropped images from a list of rotated bounding boxes
458
429
 
459
430
  Args:
460
- ----
461
431
  img: input image
462
432
  polys: bounding boxes of shape (N, 4, 2)
463
433
  dtype: target data type of bounding boxes
@@ -465,7 +435,6 @@ def extract_rcrops(
465
435
  assume_horizontal: whether the boxes are assumed to be only horizontally oriented
466
436
 
467
437
  Returns:
468
- -------
469
438
  list of cropped images
470
439
  """
471
440
  if polys.shape[0] == 0:
@@ -563,4 +532,4 @@ def extract_rcrops(
563
532
  )
564
533
  for idx in range(_boxes.shape[0])
565
534
  ]
566
- return crops # type: ignore[return-value]
535
+ return crops
doctr/utils/metrics.py CHANGED
@@ -1,9 +1,8 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Dict, List, Optional, Tuple
7
6
 
8
7
  import numpy as np
9
8
  from anyascii import anyascii
@@ -21,16 +20,14 @@ __all__ = [
21
20
  ]
22
21
 
23
22
 
24
- def string_match(word1: str, word2: str) -> Tuple[bool, bool, bool, bool]:
23
+ def string_match(word1: str, word2: str) -> tuple[bool, bool, bool, bool]:
25
24
  """Performs string comparison with multiple levels of tolerance
26
25
 
27
26
  Args:
28
- ----
29
27
  word1: a string
30
28
  word2: another string
31
29
 
32
30
  Returns:
33
- -------
34
31
  a tuple with booleans specifying respectively whether the raw strings, their lower-case counterparts, their
35
32
  anyascii counterparts and their lower-case anyascii counterparts match
36
33
  """
@@ -78,13 +75,12 @@ class TextMatch:
78
75
 
79
76
  def update(
80
77
  self,
81
- gt: List[str],
82
- pred: List[str],
78
+ gt: list[str],
79
+ pred: list[str],
83
80
  ) -> None:
84
81
  """Update the state of the metric with new predictions
85
82
 
86
83
  Args:
87
- ----
88
84
  gt: list of ground-truth character sequences
89
85
  pred: list of predicted character sequences
90
86
  """
@@ -100,11 +96,10 @@ class TextMatch:
100
96
 
101
97
  self.total += len(gt)
102
98
 
103
- def summary(self) -> Dict[str, float]:
99
+ def summary(self) -> dict[str, float]:
104
100
  """Computes the aggregated metrics
105
101
 
106
- Returns
107
- -------
102
+ Returns:
108
103
  a dictionary with the exact match score for the raw data, its lower-case counterpart, its anyascii
109
104
  counterpart and its lower-case anyascii counterpart
110
105
  """
@@ -130,12 +125,10 @@ def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray:
130
125
  """Computes the IoU between two sets of bounding boxes
131
126
 
132
127
  Args:
133
- ----
134
128
  boxes_1: bounding boxes of shape (N, 4) in format (xmin, ymin, xmax, ymax)
135
129
  boxes_2: bounding boxes of shape (M, 4) in format (xmin, ymin, xmax, ymax)
136
130
 
137
131
  Returns:
138
- -------
139
132
  the IoU matrix of shape (N, M)
140
133
  """
141
134
  iou_mat: np.ndarray = np.zeros((boxes_1.shape[0], boxes_2.shape[0]), dtype=np.float32)
@@ -160,14 +153,12 @@ def polygon_iou(polys_1: np.ndarray, polys_2: np.ndarray) -> np.ndarray:
160
153
  """Computes the IoU between two sets of rotated bounding boxes
161
154
 
162
155
  Args:
163
- ----
164
156
  polys_1: rotated bounding boxes of shape (N, 4, 2)
165
157
  polys_2: rotated bounding boxes of shape (M, 4, 2)
166
158
  mask_shape: spatial shape of the intermediate masks
167
159
  use_broadcasting: if set to True, leverage broadcasting speedup by consuming more memory
168
160
 
169
161
  Returns:
170
- -------
171
162
  the IoU matrix of shape (N, M)
172
163
  """
173
164
  if polys_1.ndim != 3 or polys_2.ndim != 3:
@@ -187,16 +178,14 @@ def polygon_iou(polys_1: np.ndarray, polys_2: np.ndarray) -> np.ndarray:
187
178
  return iou_mat
188
179
 
189
180
 
190
- def nms(boxes: np.ndarray, thresh: float = 0.5) -> List[int]:
181
+ def nms(boxes: np.ndarray, thresh: float = 0.5) -> list[int]:
191
182
  """Perform non-max suppression, borrowed from <https://github.com/rbgirshick/fast-rcnn>.
192
183
 
193
184
  Args:
194
- ----
195
185
  boxes: np array of straight boxes: (*, 5), (xmin, ymin, xmax, ymax, score)
196
186
  thresh: iou threshold to perform box suppression.
197
187
 
198
188
  Returns:
199
- -------
200
189
  A list of box indexes to keep
201
190
  """
202
191
  x1 = boxes[:, 0]
@@ -260,7 +249,6 @@ class LocalizationConfusion:
260
249
  >>> metric.summary()
261
250
 
262
251
  Args:
263
- ----
264
252
  iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match
265
253
  use_polygons: if set to True, predictions and targets will be expected to have rotated format
266
254
  """
@@ -278,7 +266,6 @@ class LocalizationConfusion:
278
266
  """Updates the metric
279
267
 
280
268
  Args:
281
- ----
282
269
  gts: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones
283
270
  preds: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones
284
271
  """
@@ -298,11 +285,10 @@ class LocalizationConfusion:
298
285
  self.num_gts += gts.shape[0]
299
286
  self.num_preds += preds.shape[0]
300
287
 
301
- def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]:
288
+ def summary(self) -> tuple[float | None, float | None, float | None]:
302
289
  """Computes the aggregated metrics
303
290
 
304
- Returns
305
- -------
291
+ Returns:
306
292
  a tuple with the recall, precision and meanIoU scores
307
293
  """
308
294
  # Recall
@@ -360,7 +346,6 @@ class OCRMetric:
360
346
  >>> metric.summary()
361
347
 
362
348
  Args:
363
- ----
364
349
  iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match
365
350
  use_polygons: if set to True, predictions and targets will be expected to have rotated format
366
351
  """
@@ -378,13 +363,12 @@ class OCRMetric:
378
363
  self,
379
364
  gt_boxes: np.ndarray,
380
365
  pred_boxes: np.ndarray,
381
- gt_labels: List[str],
382
- pred_labels: List[str],
366
+ gt_labels: list[str],
367
+ pred_labels: list[str],
383
368
  ) -> None:
384
369
  """Updates the metric
385
370
 
386
371
  Args:
387
- ----
388
372
  gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones
389
373
  pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones
390
374
  gt_labels: a list of N string labels
@@ -392,7 +376,7 @@ class OCRMetric:
392
376
  """
393
377
  if gt_boxes.shape[0] != len(gt_labels) or pred_boxes.shape[0] != len(pred_labels):
394
378
  raise AssertionError(
395
- "there should be the same number of boxes and string both for the ground truth " "and the predictions"
379
+ "there should be the same number of boxes and string both for the ground truth and the predictions"
396
380
  )
397
381
 
398
382
  # Compute IoU
@@ -418,11 +402,10 @@ class OCRMetric:
418
402
  self.num_gts += gt_boxes.shape[0]
419
403
  self.num_preds += pred_boxes.shape[0]
420
404
 
421
- def summary(self) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]], Optional[float]]:
405
+ def summary(self) -> tuple[dict[str, float | None], dict[str, float | None], float | None]:
422
406
  """Computes the aggregated metrics
423
407
 
424
- Returns
425
- -------
408
+ Returns:
426
409
  a tuple with the recall & precision for each string comparison and the mean IoU
427
410
  """
428
411
  # Recall
@@ -493,7 +476,6 @@ class DetectionMetric:
493
476
  >>> metric.summary()
494
477
 
495
478
  Args:
496
- ----
497
479
  iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match
498
480
  use_polygons: if set to True, predictions and targets will be expected to have rotated format
499
481
  """
@@ -517,7 +499,6 @@ class DetectionMetric:
517
499
  """Updates the metric
518
500
 
519
501
  Args:
520
- ----
521
502
  gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones
522
503
  pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones
523
504
  gt_labels: an array of class indices of shape (N,)
@@ -525,7 +506,7 @@ class DetectionMetric:
525
506
  """
526
507
  if gt_boxes.shape[0] != gt_labels.shape[0] or pred_boxes.shape[0] != pred_labels.shape[0]:
527
508
  raise AssertionError(
528
- "there should be the same number of boxes and string both for the ground truth " "and the predictions"
509
+ "there should be the same number of boxes and string both for the ground truth and the predictions"
529
510
  )
530
511
 
531
512
  # Compute IoU
@@ -546,11 +527,10 @@ class DetectionMetric:
546
527
  self.num_gts += gt_boxes.shape[0]
547
528
  self.num_preds += pred_boxes.shape[0]
548
529
 
549
- def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]:
530
+ def summary(self) -> tuple[float | None, float | None, float | None]:
550
531
  """Computes the aggregated metrics
551
532
 
552
- Returns
553
- -------
533
+ Returns:
554
534
  a tuple with the recall & precision for each class prediction and the mean IoU
555
535
  """
556
536
  # Recall
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,15 +6,16 @@
6
6
 
7
7
  import multiprocessing as mp
8
8
  import os
9
+ from collections.abc import Callable, Iterable, Iterator
9
10
  from multiprocessing.pool import ThreadPool
10
- from typing import Any, Callable, Iterable, Iterator, Optional
11
+ from typing import Any
11
12
 
12
13
  from doctr.file_utils import ENV_VARS_TRUE_VALUES
13
14
 
14
15
  __all__ = ["multithread_exec"]
15
16
 
16
17
 
17
- def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Optional[int] = None) -> Iterator[Any]:
18
+ def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: int | None = None) -> Iterator[Any]:
18
19
  """Execute a given function in parallel for each element of a given sequence
19
20
 
20
21
  >>> from doctr.utils.multithreading import multithread_exec
@@ -22,17 +23,14 @@ def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Op
22
23
  >>> results = multithread_exec(lambda x: x ** 2, entries)
23
24
 
24
25
  Args:
25
- ----
26
26
  func: function to be executed on each element of the iterable
27
27
  seq: iterable
28
28
  threads: number of workers to be used for multiprocessing
29
29
 
30
30
  Returns:
31
- -------
32
31
  iterator of the function's results using the iterable as inputs
33
32
 
34
33
  Notes:
35
- -----
36
34
  This function uses ThreadPool from multiprocessing package, which uses `/dev/shm` directory for shared memory.
37
35
  If you do not have write permissions for this directory (if you run `doctr` on AWS Lambda for instance),
38
36
  you might want to disable multiprocessing. To achieve that, set 'DOCTR_MULTIPROCESSING_DISABLE' to 'TRUE'.