python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/utils/geometry.py CHANGED
@@ -1,11 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from copy import deepcopy
7
7
  from math import ceil
8
- from typing import List, Optional, Tuple, Union
9
8
 
10
9
  import cv2
11
10
  import numpy as np
@@ -20,6 +19,7 @@ __all__ = [
20
19
  "rotate_boxes",
21
20
  "compute_expanded_shape",
22
21
  "rotate_image",
22
+ "remove_image_padding",
23
23
  "estimate_page_angle",
24
24
  "convert_to_relative_coords",
25
25
  "rotate_abs_geoms",
@@ -33,11 +33,9 @@ def bbox_to_polygon(bbox: BoundingBox) -> Polygon4P:
33
33
  """Convert a bounding box to a polygon
34
34
 
35
35
  Args:
36
- ----
37
36
  bbox: a bounding box
38
37
 
39
38
  Returns:
40
- -------
41
39
  a polygon
42
40
  """
43
41
  return bbox[0], (bbox[1][0], bbox[0][1]), (bbox[0][0], bbox[1][1]), bbox[1]
@@ -47,31 +45,27 @@ def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox:
47
45
  """Convert a polygon to a bounding box
48
46
 
49
47
  Args:
50
- ----
51
48
  polygon: a polygon
52
49
 
53
50
  Returns:
54
- -------
55
51
  a bounding box
56
52
  """
57
53
  x, y = zip(*polygon)
58
54
  return (min(x), min(y)), (max(x), max(y))
59
55
 
60
56
 
61
- def detach_scores(boxes: List[np.ndarray]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
57
+ def detach_scores(boxes: list[np.ndarray]) -> tuple[list[np.ndarray], list[np.ndarray]]:
62
58
  """Detach the objectness scores from box predictions
63
59
 
64
60
  Args:
65
- ----
66
61
  boxes: list of arrays with boxes of shape (N, 5) or (N, 5, 2)
67
62
 
68
63
  Returns:
69
- -------
70
64
  a tuple of two lists: the first one contains the boxes without the objectness scores,
71
65
  the second one contains the objectness scores
72
66
  """
73
67
 
74
- def _detach(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
68
+ def _detach(boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
75
69
  if boxes.ndim == 2:
76
70
  return boxes[:, :-1], boxes[:, -1]
77
71
  return boxes[:, :-1], boxes[:, -1, -1]
@@ -80,11 +74,10 @@ def detach_scores(boxes: List[np.ndarray]) -> Tuple[List[np.ndarray], List[np.nd
80
74
  return list(loc_preds), list(obj_scores)
81
75
 
82
76
 
83
- def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Union[BoundingBox, np.ndarray]:
77
+ def resolve_enclosing_bbox(bboxes: list[BoundingBox] | np.ndarray) -> BoundingBox | np.ndarray:
84
78
  """Compute enclosing bbox either from:
85
79
 
86
80
  Args:
87
- ----
88
81
  bboxes: boxes in one of the following formats:
89
82
 
90
83
  - an array of boxes: (*, 4), where boxes have this shape:
@@ -93,7 +86,6 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio
93
86
  - a list of BoundingBox
94
87
 
95
88
  Returns:
96
- -------
97
89
  a (1, 4) array (enclosing boxarray), or a BoundingBox
98
90
  """
99
91
  if isinstance(bboxes, np.ndarray):
@@ -104,11 +96,10 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio
104
96
  return (min(x), min(y)), (max(x), max(y))
105
97
 
106
98
 
107
- def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024) -> np.ndarray:
99
+ def resolve_enclosing_rbbox(rbboxes: list[np.ndarray], intermed_size: int = 1024) -> np.ndarray:
108
100
  """Compute enclosing rotated bbox either from:
109
101
 
110
102
  Args:
111
- ----
112
103
  rbboxes: boxes in one of the following formats:
113
104
 
114
105
  - an array of boxes: (*, 4, 2), where boxes have this shape:
@@ -118,26 +109,23 @@ def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024
118
109
  intermed_size: size of the intermediate image
119
110
 
120
111
  Returns:
121
- -------
122
112
  a (4, 2) array (enclosing rotated box)
123
113
  """
124
114
  cloud: np.ndarray = np.concatenate(rbboxes, axis=0)
125
115
  # Convert to absolute for minAreaRect
126
116
  cloud *= intermed_size
127
117
  rect = cv2.minAreaRect(cloud.astype(np.int32))
128
- return cv2.boxPoints(rect) / intermed_size # type: ignore[return-value]
118
+ return cv2.boxPoints(rect) / intermed_size
129
119
 
130
120
 
131
121
  def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray:
132
122
  """Rotate points counter-clockwise.
133
123
 
134
124
  Args:
135
- ----
136
125
  points: array of size (N, 2)
137
126
  angle: angle between -90 and +90 degrees
138
127
 
139
128
  Returns:
140
- -------
141
129
  Rotated points
142
130
  """
143
131
  angle_rad = angle * np.pi / 180.0 # compute radian angle for np functions
@@ -147,16 +135,14 @@ def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray:
147
135
  return np.matmul(points, rotation_mat.T)
148
136
 
149
137
 
150
- def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[int, int]:
138
+ def compute_expanded_shape(img_shape: tuple[int, int], angle: float) -> tuple[int, int]:
151
139
  """Compute the shape of an expanded rotated image
152
140
 
153
141
  Args:
154
- ----
155
142
  img_shape: the height and width of the image
156
143
  angle: angle between -90 and +90 degrees
157
144
 
158
145
  Returns:
159
- -------
160
146
  the height and width of the rotated image
161
147
  """
162
148
  points: np.ndarray = np.array([
@@ -173,21 +159,19 @@ def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[in
173
159
  def rotate_abs_geoms(
174
160
  geoms: np.ndarray,
175
161
  angle: float,
176
- img_shape: Tuple[int, int],
162
+ img_shape: tuple[int, int],
177
163
  expand: bool = True,
178
164
  ) -> np.ndarray:
179
165
  """Rotate a batch of bounding boxes or polygons by an angle around the
180
166
  image center.
181
167
 
182
168
  Args:
183
- ----
184
169
  geoms: (N, 4) or (N, 4, 2) array of ABSOLUTE coordinate boxes
185
170
  angle: anti-clockwise rotation angle in degrees
186
171
  img_shape: the height and width of the image
187
172
  expand: whether the image should be padded to avoid information loss
188
173
 
189
174
  Returns:
190
- -------
191
175
  A batch of rotated polygons (N, 4, 2)
192
176
  """
193
177
  # Switch to polygons
@@ -213,19 +197,17 @@ def rotate_abs_geoms(
213
197
  return rotated_polys
214
198
 
215
199
 
216
- def remap_boxes(loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape: Tuple[int, int]) -> np.ndarray:
200
+ def remap_boxes(loc_preds: np.ndarray, orig_shape: tuple[int, int], dest_shape: tuple[int, int]) -> np.ndarray:
217
201
  """Remaps a batch of rotated locpred (N, 4, 2) expressed for an origin_shape to a destination_shape.
218
202
  This does not impact the absolute shape of the boxes, but allow to calculate the new relative RotatedBbox
219
203
  coordinates after a resizing of the image.
220
204
 
221
205
  Args:
222
- ----
223
206
  loc_preds: (N, 4, 2) array of RELATIVE loc_preds
224
207
  orig_shape: shape of the origin image
225
208
  dest_shape: shape of the destination image
226
209
 
227
210
  Returns:
228
- -------
229
211
  A batch of rotated loc_preds (N, 4, 2) expressed in the destination referencial
230
212
  """
231
213
  if len(dest_shape) != 2:
@@ -244,9 +226,9 @@ def remap_boxes(loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape:
244
226
  def rotate_boxes(
245
227
  loc_preds: np.ndarray,
246
228
  angle: float,
247
- orig_shape: Tuple[int, int],
229
+ orig_shape: tuple[int, int],
248
230
  min_angle: float = 1.0,
249
- target_shape: Optional[Tuple[int, int]] = None,
231
+ target_shape: tuple[int, int] | None = None,
250
232
  ) -> np.ndarray:
251
233
  """Rotate a batch of straight bounding boxes (xmin, ymin, xmax, ymax, c) or rotated bounding boxes
252
234
  (4, 2) of an angle, if angle > min_angle, around the center of the page.
@@ -254,7 +236,6 @@ def rotate_boxes(
254
236
  is done to remove the padding that is created by rotate_page(expand=True)
255
237
 
256
238
  Args:
257
- ----
258
239
  loc_preds: (N, 4) or (N, 4, 2) array of RELATIVE boxes
259
240
  angle: angle between -90 and +90 degrees
260
241
  orig_shape: shape of the origin image
@@ -262,7 +243,6 @@ def rotate_boxes(
262
243
  target_shape: shape of the destination image
263
244
 
264
245
  Returns:
265
- -------
266
246
  A batch of rotated boxes (N, 4, 2): or a batch of straight bounding boxes
267
247
  """
268
248
  # Change format of the boxes to rotated boxes
@@ -309,14 +289,12 @@ def rotate_image(
309
289
  """Rotate an image counterclockwise by an given angle.
310
290
 
311
291
  Args:
312
- ----
313
292
  image: numpy tensor to rotate
314
293
  angle: rotation angle in degrees, between -90 and +90
315
294
  expand: whether the image should be padded before the rotation
316
295
  preserve_origin_shape: if expand is set to True, resizes the final output to the original image size
317
296
 
318
297
  Returns:
319
- -------
320
298
  Rotated array, padded by 0 by default.
321
299
  """
322
300
  # Compute the expanded padding
@@ -343,7 +321,7 @@ def rotate_image(
343
321
  # Pad height
344
322
  else:
345
323
  h_pad, w_pad = int(rot_img.shape[1] * image.shape[0] / image.shape[1] - rot_img.shape[0]), 0
346
- rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0))) # type: ignore[assignment]
324
+ rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0)))
347
325
  if preserve_origin_shape:
348
326
  # rescale
349
327
  rot_img = cv2.resize(rot_img, image.shape[:-1][::-1], interpolation=cv2.INTER_LINEAR)
@@ -351,6 +329,24 @@ def rotate_image(
351
329
  return rot_img
352
330
 
353
331
 
332
+ def remove_image_padding(image: np.ndarray) -> np.ndarray:
333
+ """Remove black border padding from an image
334
+
335
+ Args:
336
+ image: numpy tensor to remove padding from
337
+
338
+ Returns:
339
+ Image with padding removed
340
+ """
341
+ # Find the bounding box of the non-black region
342
+ rows = np.any(image, axis=1)
343
+ cols = np.any(image, axis=0)
344
+ rmin, rmax = np.where(rows)[0][[0, -1]]
345
+ cmin, cmax = np.where(cols)[0][[0, -1]]
346
+
347
+ return image[rmin : rmax + 1, cmin : cmax + 1]
348
+
349
+
354
350
  def estimate_page_angle(polys: np.ndarray) -> float:
355
351
  """Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the
356
352
  estimated angle ccw in degrees
@@ -369,16 +365,14 @@ def estimate_page_angle(polys: np.ndarray) -> float:
369
365
  return 0.0
370
366
 
371
367
 
372
- def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray:
368
+ def convert_to_relative_coords(geoms: np.ndarray, img_shape: tuple[int, int]) -> np.ndarray:
373
369
  """Convert a geometry to relative coordinates
374
370
 
375
371
  Args:
376
- ----
377
372
  geoms: a set of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4)
378
373
  img_shape: the height and width of the image
379
374
 
380
375
  Returns:
381
- -------
382
376
  the updated geometry
383
377
  """
384
378
  # Polygon
@@ -396,18 +390,16 @@ def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) ->
396
390
  raise ValueError(f"invalid format for arg `geoms`: {geoms.shape}")
397
391
 
398
392
 
399
- def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> List[np.ndarray]:
393
+ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> list[np.ndarray]:
400
394
  """Created cropped images from list of bounding boxes
401
395
 
402
396
  Args:
403
- ----
404
397
  img: input image
405
398
  boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative
406
399
  coordinates (xmin, ymin, xmax, ymax)
407
400
  channels_last: whether the channel dimensions is the last one instead of the last one
408
401
 
409
402
  Returns:
410
- -------
411
403
  list of cropped images
412
404
  """
413
405
  if boxes.shape[0] == 0:
@@ -431,19 +423,18 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True
431
423
 
432
424
 
433
425
  def extract_rcrops(
434
- img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True
435
- ) -> List[np.ndarray]:
426
+ img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True, assume_horizontal: bool = False
427
+ ) -> list[np.ndarray]:
436
428
  """Created cropped images from list of rotated bounding boxes
437
429
 
438
430
  Args:
439
- ----
440
431
  img: input image
441
432
  polys: bounding boxes of shape (N, 4, 2)
442
433
  dtype: target data type of bounding boxes
443
434
  channels_last: whether the channel dimensions is the last one instead of the last one
435
+ assume_horizontal: whether the boxes are assumed to be only horizontally oriented
444
436
 
445
437
  Returns:
446
- -------
447
438
  list of cropped images
448
439
  """
449
440
  if polys.shape[0] == 0:
@@ -458,22 +449,87 @@ def extract_rcrops(
458
449
  _boxes[:, :, 0] *= width
459
450
  _boxes[:, :, 1] *= height
460
451
 
461
- src_pts = _boxes[:, :3].astype(np.float32)
462
- # Preserve size
463
- d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1)
464
- d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1)
465
- # (N, 3, 2)
466
- dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype)
467
- dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1
468
- dst_pts[:, 2, 1] = d2 - 1
469
- # Use a warp transformation to extract the crop
470
- crops = [
471
- cv2.warpAffine(
472
- img if channels_last else img.transpose(1, 2, 0),
473
- # Transformation matrix
474
- cv2.getAffineTransform(src_pts[idx], dst_pts[idx]),
475
- (int(d1[idx]), int(d2[idx])),
476
- )
477
- for idx in range(_boxes.shape[0])
478
- ]
479
- return crops # type: ignore[return-value]
452
+ src_img = img if channels_last else img.transpose(1, 2, 0)
453
+
454
+ # Handle only horizontal oriented boxes
455
+ if assume_horizontal:
456
+ crops = []
457
+
458
+ for box in _boxes:
459
+ # Calculate the centroid of the quadrilateral
460
+ centroid = np.mean(box, axis=0)
461
+
462
+ # Divide the points into left and right
463
+ left_points = box[box[:, 0] < centroid[0]]
464
+ right_points = box[box[:, 0] >= centroid[0]]
465
+
466
+ # Sort the left points according to the y-axis
467
+ left_points = left_points[np.argsort(left_points[:, 1])]
468
+ top_left_pt = left_points[0]
469
+ bottom_left_pt = left_points[-1]
470
+ # Sort the right points according to the y-axis
471
+ right_points = right_points[np.argsort(right_points[:, 1])]
472
+ top_right_pt = right_points[0]
473
+ bottom_right_pt = right_points[-1]
474
+ box_points = np.array(
475
+ [top_left_pt, bottom_left_pt, top_right_pt, bottom_right_pt],
476
+ dtype=dtype,
477
+ )
478
+
479
+ # Get the width and height of the rectangle that will contain the warped quadrilateral
480
+ width_upper = np.linalg.norm(top_right_pt - top_left_pt)
481
+ width_lower = np.linalg.norm(bottom_right_pt - bottom_left_pt)
482
+ height_left = np.linalg.norm(bottom_left_pt - top_left_pt)
483
+ height_right = np.linalg.norm(bottom_right_pt - top_right_pt)
484
+
485
+ # Get the maximum width and height
486
+ rect_width = max(int(width_upper), int(width_lower))
487
+ rect_height = max(int(height_left), int(height_right))
488
+
489
+ dst_pts = np.array(
490
+ [
491
+ [0, 0], # top-left
492
+ # bottom-left
493
+ [0, rect_height - 1],
494
+ # top-right
495
+ [rect_width - 1, 0],
496
+ # bottom-right
497
+ [rect_width - 1, rect_height - 1],
498
+ ],
499
+ dtype=dtype,
500
+ )
501
+
502
+ # Get the perspective transform matrix using the box points
503
+ affine_mat = cv2.getPerspectiveTransform(box_points, dst_pts)
504
+
505
+ # Perform the perspective warp to get the rectified crop
506
+ crop = cv2.warpPerspective(
507
+ src_img,
508
+ affine_mat,
509
+ (rect_width, rect_height),
510
+ )
511
+
512
+ # Add the crop to the list of crops
513
+ crops.append(crop)
514
+
515
+ # Handle any oriented boxes
516
+ else:
517
+ src_pts = _boxes[:, :3].astype(np.float32)
518
+ # Preserve size
519
+ d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1)
520
+ d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1)
521
+ # (N, 3, 2)
522
+ dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype)
523
+ dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1
524
+ dst_pts[:, 2, 1] = d2 - 1
525
+ # Use a warp transformation to extract the crop
526
+ crops = [
527
+ cv2.warpAffine(
528
+ src_img,
529
+ # Transformation matrix
530
+ cv2.getAffineTransform(src_pts[idx], dst_pts[idx]),
531
+ (int(d1[idx]), int(d2[idx])),
532
+ )
533
+ for idx in range(_boxes.shape[0])
534
+ ]
535
+ return crops
doctr/utils/metrics.py CHANGED
@@ -1,9 +1,8 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Dict, List, Optional, Tuple
7
6
 
8
7
  import numpy as np
9
8
  from anyascii import anyascii
@@ -21,16 +20,14 @@ __all__ = [
21
20
  ]
22
21
 
23
22
 
24
- def string_match(word1: str, word2: str) -> Tuple[bool, bool, bool, bool]:
23
+ def string_match(word1: str, word2: str) -> tuple[bool, bool, bool, bool]:
25
24
  """Performs string comparison with multiple levels of tolerance
26
25
 
27
26
  Args:
28
- ----
29
27
  word1: a string
30
28
  word2: another string
31
29
 
32
30
  Returns:
33
- -------
34
31
  a tuple with booleans specifying respectively whether the raw strings, their lower-case counterparts, their
35
32
  anyascii counterparts and their lower-case anyascii counterparts match
36
33
  """
@@ -78,13 +75,12 @@ class TextMatch:
78
75
 
79
76
  def update(
80
77
  self,
81
- gt: List[str],
82
- pred: List[str],
78
+ gt: list[str],
79
+ pred: list[str],
83
80
  ) -> None:
84
81
  """Update the state of the metric with new predictions
85
82
 
86
83
  Args:
87
- ----
88
84
  gt: list of groung-truth character sequences
89
85
  pred: list of predicted character sequences
90
86
  """
@@ -100,11 +96,10 @@ class TextMatch:
100
96
 
101
97
  self.total += len(gt)
102
98
 
103
- def summary(self) -> Dict[str, float]:
99
+ def summary(self) -> dict[str, float]:
104
100
  """Computes the aggregated metrics
105
101
 
106
- Returns
107
- -------
102
+ Returns:
108
103
  a dictionary with the exact match score for the raw data, its lower-case counterpart, its anyascii
109
104
  counterpart and its lower-case anyascii counterpart
110
105
  """
@@ -130,12 +125,10 @@ def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray:
130
125
  """Computes the IoU between two sets of bounding boxes
131
126
 
132
127
  Args:
133
- ----
134
128
  boxes_1: bounding boxes of shape (N, 4) in format (xmin, ymin, xmax, ymax)
135
129
  boxes_2: bounding boxes of shape (M, 4) in format (xmin, ymin, xmax, ymax)
136
130
 
137
131
  Returns:
138
- -------
139
132
  the IoU matrix of shape (N, M)
140
133
  """
141
134
  iou_mat: np.ndarray = np.zeros((boxes_1.shape[0], boxes_2.shape[0]), dtype=np.float32)
@@ -149,7 +142,7 @@ def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray:
149
142
  right = np.minimum(r1, r2.T)
150
143
  bot = np.minimum(b1, b2.T)
151
144
 
152
- intersection = np.clip(right - left, 0, np.Inf) * np.clip(bot - top, 0, np.Inf)
145
+ intersection = np.clip(right - left, 0, np.inf) * np.clip(bot - top, 0, np.inf)
153
146
  union = (r1 - l1) * (b1 - t1) + ((r2 - l2) * (b2 - t2)).T - intersection
154
147
  iou_mat = intersection / union
155
148
 
@@ -160,14 +153,12 @@ def polygon_iou(polys_1: np.ndarray, polys_2: np.ndarray) -> np.ndarray:
160
153
  """Computes the IoU between two sets of rotated bounding boxes
161
154
 
162
155
  Args:
163
- ----
164
156
  polys_1: rotated bounding boxes of shape (N, 4, 2)
165
157
  polys_2: rotated bounding boxes of shape (M, 4, 2)
166
158
  mask_shape: spatial shape of the intermediate masks
167
159
  use_broadcasting: if set to True, leverage broadcasting speedup by consuming more memory
168
160
 
169
161
  Returns:
170
- -------
171
162
  the IoU matrix of shape (N, M)
172
163
  """
173
164
  if polys_1.ndim != 3 or polys_2.ndim != 3:
@@ -187,16 +178,14 @@ def polygon_iou(polys_1: np.ndarray, polys_2: np.ndarray) -> np.ndarray:
187
178
  return iou_mat
188
179
 
189
180
 
190
- def nms(boxes: np.ndarray, thresh: float = 0.5) -> List[int]:
181
+ def nms(boxes: np.ndarray, thresh: float = 0.5) -> list[int]:
191
182
  """Perform non-max suppression, borrowed from <https://github.com/rbgirshick/fast-rcnn>`_.
192
183
 
193
184
  Args:
194
- ----
195
185
  boxes: np array of straight boxes: (*, 5), (xmin, ymin, xmax, ymax, score)
196
186
  thresh: iou threshold to perform box suppression.
197
187
 
198
188
  Returns:
199
- -------
200
189
  A list of box indexes to keep
201
190
  """
202
191
  x1 = boxes[:, 0]
@@ -260,7 +249,6 @@ class LocalizationConfusion:
260
249
  >>> metric.summary()
261
250
 
262
251
  Args:
263
- ----
264
252
  iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match
265
253
  use_polygons: if set to True, predictions and targets will be expected to have rotated format
266
254
  """
@@ -278,7 +266,6 @@ class LocalizationConfusion:
278
266
  """Updates the metric
279
267
 
280
268
  Args:
281
- ----
282
269
  gts: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones
283
270
  preds: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones
284
271
  """
@@ -298,11 +285,10 @@ class LocalizationConfusion:
298
285
  self.num_gts += gts.shape[0]
299
286
  self.num_preds += preds.shape[0]
300
287
 
301
- def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]:
288
+ def summary(self) -> tuple[float | None, float | None, float | None]:
302
289
  """Computes the aggregated metrics
303
290
 
304
- Returns
305
- -------
291
+ Returns:
306
292
  a tuple with the recall, precision and meanIoU scores
307
293
  """
308
294
  # Recall
@@ -360,7 +346,6 @@ class OCRMetric:
360
346
  >>> metric.summary()
361
347
 
362
348
  Args:
363
- ----
364
349
  iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match
365
350
  use_polygons: if set to True, predictions and targets will be expected to have rotated format
366
351
  """
@@ -378,13 +363,12 @@ class OCRMetric:
378
363
  self,
379
364
  gt_boxes: np.ndarray,
380
365
  pred_boxes: np.ndarray,
381
- gt_labels: List[str],
382
- pred_labels: List[str],
366
+ gt_labels: list[str],
367
+ pred_labels: list[str],
383
368
  ) -> None:
384
369
  """Updates the metric
385
370
 
386
371
  Args:
387
- ----
388
372
  gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones
389
373
  pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones
390
374
  gt_labels: a list of N string labels
@@ -392,7 +376,7 @@ class OCRMetric:
392
376
  """
393
377
  if gt_boxes.shape[0] != len(gt_labels) or pred_boxes.shape[0] != len(pred_labels):
394
378
  raise AssertionError(
395
- "there should be the same number of boxes and string both for the ground truth " "and the predictions"
379
+ "there should be the same number of boxes and string both for the ground truth and the predictions"
396
380
  )
397
381
 
398
382
  # Compute IoU
@@ -418,11 +402,10 @@ class OCRMetric:
418
402
  self.num_gts += gt_boxes.shape[0]
419
403
  self.num_preds += pred_boxes.shape[0]
420
404
 
421
- def summary(self) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]], Optional[float]]:
405
+ def summary(self) -> tuple[dict[str, float | None], dict[str, float | None], float | None]:
422
406
  """Computes the aggregated metrics
423
407
 
424
- Returns
425
- -------
408
+ Returns:
426
409
  a tuple with the recall & precision for each string comparison and the mean IoU
427
410
  """
428
411
  # Recall
@@ -493,7 +476,6 @@ class DetectionMetric:
493
476
  >>> metric.summary()
494
477
 
495
478
  Args:
496
- ----
497
479
  iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match
498
480
  use_polygons: if set to True, predictions and targets will be expected to have rotated format
499
481
  """
@@ -517,7 +499,6 @@ class DetectionMetric:
517
499
  """Updates the metric
518
500
 
519
501
  Args:
520
- ----
521
502
  gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones
522
503
  pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones
523
504
  gt_labels: an array of class indices of shape (N,)
@@ -525,7 +506,7 @@ class DetectionMetric:
525
506
  """
526
507
  if gt_boxes.shape[0] != gt_labels.shape[0] or pred_boxes.shape[0] != pred_labels.shape[0]:
527
508
  raise AssertionError(
528
- "there should be the same number of boxes and string both for the ground truth " "and the predictions"
509
+ "there should be the same number of boxes and string both for the ground truth and the predictions"
529
510
  )
530
511
 
531
512
  # Compute IoU
@@ -546,11 +527,10 @@ class DetectionMetric:
546
527
  self.num_gts += gt_boxes.shape[0]
547
528
  self.num_preds += pred_boxes.shape[0]
548
529
 
549
- def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]:
530
+ def summary(self) -> tuple[float | None, float | None, float | None]:
550
531
  """Computes the aggregated metrics
551
532
 
552
- Returns
553
- -------
533
+ Returns:
554
534
  a tuple with the recall & precision for each class prediction and the mean IoU
555
535
  """
556
536
  # Recall