python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/utils/geometry.py
CHANGED
|
@@ -1,11 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-2024, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
7
|
from math import ceil
|
|
8
|
-
from typing import List, Optional, Tuple, Union
|
|
9
8
|
|
|
10
9
|
import cv2
|
|
11
10
|
import numpy as np
|
|
@@ -20,6 +19,7 @@ __all__ = [
|
|
|
20
19
|
"rotate_boxes",
|
|
21
20
|
"compute_expanded_shape",
|
|
22
21
|
"rotate_image",
|
|
22
|
+
"remove_image_padding",
|
|
23
23
|
"estimate_page_angle",
|
|
24
24
|
"convert_to_relative_coords",
|
|
25
25
|
"rotate_abs_geoms",
|
|
@@ -33,11 +33,9 @@ def bbox_to_polygon(bbox: BoundingBox) -> Polygon4P:
|
|
|
33
33
|
"""Convert a bounding box to a polygon
|
|
34
34
|
|
|
35
35
|
Args:
|
|
36
|
-
----
|
|
37
36
|
bbox: a bounding box
|
|
38
37
|
|
|
39
38
|
Returns:
|
|
40
|
-
-------
|
|
41
39
|
a polygon
|
|
42
40
|
"""
|
|
43
41
|
return bbox[0], (bbox[1][0], bbox[0][1]), (bbox[0][0], bbox[1][1]), bbox[1]
|
|
@@ -47,31 +45,27 @@ def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox:
|
|
|
47
45
|
"""Convert a polygon to a bounding box
|
|
48
46
|
|
|
49
47
|
Args:
|
|
50
|
-
----
|
|
51
48
|
polygon: a polygon
|
|
52
49
|
|
|
53
50
|
Returns:
|
|
54
|
-
-------
|
|
55
51
|
a bounding box
|
|
56
52
|
"""
|
|
57
53
|
x, y = zip(*polygon)
|
|
58
54
|
return (min(x), min(y)), (max(x), max(y))
|
|
59
55
|
|
|
60
56
|
|
|
61
|
-
def detach_scores(boxes: List[np.ndarray]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
|
|
57
|
+
def detach_scores(boxes: list[np.ndarray]) -> tuple[list[np.ndarray], list[np.ndarray]]:
|
|
62
58
|
"""Detach the objectness scores from box predictions
|
|
63
59
|
|
|
64
60
|
Args:
|
|
65
|
-
----
|
|
66
61
|
boxes: list of arrays with boxes of shape (N, 5) or (N, 5, 2)
|
|
67
62
|
|
|
68
63
|
Returns:
|
|
69
|
-
-------
|
|
70
64
|
a tuple of two lists: the first one contains the boxes without the objectness scores,
|
|
71
65
|
the second one contains the objectness scores
|
|
72
66
|
"""
|
|
73
67
|
|
|
74
|
-
def _detach(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
|
68
|
+
def _detach(boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
|
75
69
|
if boxes.ndim == 2:
|
|
76
70
|
return boxes[:, :-1], boxes[:, -1]
|
|
77
71
|
return boxes[:, :-1], boxes[:, -1, -1]
|
|
@@ -80,11 +74,10 @@ def detach_scores(boxes: List[np.ndarray]) -> Tuple[List[np.ndarray], List[np.nd
|
|
|
80
74
|
return list(loc_preds), list(obj_scores)
|
|
81
75
|
|
|
82
76
|
|
|
83
|
-
def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Union[BoundingBox, np.ndarray]:
|
|
77
|
+
def resolve_enclosing_bbox(bboxes: list[BoundingBox] | np.ndarray) -> BoundingBox | np.ndarray:
|
|
84
78
|
"""Compute enclosing bbox either from:
|
|
85
79
|
|
|
86
80
|
Args:
|
|
87
|
-
----
|
|
88
81
|
bboxes: boxes in one of the following formats:
|
|
89
82
|
|
|
90
83
|
- an array of boxes: (*, 4), where boxes have this shape:
|
|
@@ -93,7 +86,6 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio
|
|
|
93
86
|
- a list of BoundingBox
|
|
94
87
|
|
|
95
88
|
Returns:
|
|
96
|
-
-------
|
|
97
89
|
a (1, 4) array (enclosing boxarray), or a BoundingBox
|
|
98
90
|
"""
|
|
99
91
|
if isinstance(bboxes, np.ndarray):
|
|
@@ -104,11 +96,10 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio
|
|
|
104
96
|
return (min(x), min(y)), (max(x), max(y))
|
|
105
97
|
|
|
106
98
|
|
|
107
|
-
def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024) -> np.ndarray:
|
|
99
|
+
def resolve_enclosing_rbbox(rbboxes: list[np.ndarray], intermed_size: int = 1024) -> np.ndarray:
|
|
108
100
|
"""Compute enclosing rotated bbox either from:
|
|
109
101
|
|
|
110
102
|
Args:
|
|
111
|
-
----
|
|
112
103
|
rbboxes: boxes in one of the following formats:
|
|
113
104
|
|
|
114
105
|
- an array of boxes: (*, 4, 2), where boxes have this shape:
|
|
@@ -118,26 +109,23 @@ def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024
|
|
|
118
109
|
intermed_size: size of the intermediate image
|
|
119
110
|
|
|
120
111
|
Returns:
|
|
121
|
-
-------
|
|
122
112
|
a (4, 2) array (enclosing rotated box)
|
|
123
113
|
"""
|
|
124
114
|
cloud: np.ndarray = np.concatenate(rbboxes, axis=0)
|
|
125
115
|
# Convert to absolute for minAreaRect
|
|
126
116
|
cloud *= intermed_size
|
|
127
117
|
rect = cv2.minAreaRect(cloud.astype(np.int32))
|
|
128
|
-
return cv2.boxPoints(rect) / intermed_size
|
|
118
|
+
return cv2.boxPoints(rect) / intermed_size
|
|
129
119
|
|
|
130
120
|
|
|
131
121
|
def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray:
|
|
132
122
|
"""Rotate points counter-clockwise.
|
|
133
123
|
|
|
134
124
|
Args:
|
|
135
|
-
----
|
|
136
125
|
points: array of size (N, 2)
|
|
137
126
|
angle: angle between -90 and +90 degrees
|
|
138
127
|
|
|
139
128
|
Returns:
|
|
140
|
-
-------
|
|
141
129
|
Rotated points
|
|
142
130
|
"""
|
|
143
131
|
angle_rad = angle * np.pi / 180.0 # compute radian angle for np functions
|
|
@@ -147,16 +135,14 @@ def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray:
|
|
|
147
135
|
return np.matmul(points, rotation_mat.T)
|
|
148
136
|
|
|
149
137
|
|
|
150
|
-
def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[int, int]:
|
|
138
|
+
def compute_expanded_shape(img_shape: tuple[int, int], angle: float) -> tuple[int, int]:
|
|
151
139
|
"""Compute the shape of an expanded rotated image
|
|
152
140
|
|
|
153
141
|
Args:
|
|
154
|
-
----
|
|
155
142
|
img_shape: the height and width of the image
|
|
156
143
|
angle: angle between -90 and +90 degrees
|
|
157
144
|
|
|
158
145
|
Returns:
|
|
159
|
-
-------
|
|
160
146
|
the height and width of the rotated image
|
|
161
147
|
"""
|
|
162
148
|
points: np.ndarray = np.array([
|
|
@@ -173,21 +159,19 @@ def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[in
|
|
|
173
159
|
def rotate_abs_geoms(
|
|
174
160
|
geoms: np.ndarray,
|
|
175
161
|
angle: float,
|
|
176
|
-
img_shape: Tuple[int, int],
|
|
162
|
+
img_shape: tuple[int, int],
|
|
177
163
|
expand: bool = True,
|
|
178
164
|
) -> np.ndarray:
|
|
179
165
|
"""Rotate a batch of bounding boxes or polygons by an angle around the
|
|
180
166
|
image center.
|
|
181
167
|
|
|
182
168
|
Args:
|
|
183
|
-
----
|
|
184
169
|
geoms: (N, 4) or (N, 4, 2) array of ABSOLUTE coordinate boxes
|
|
185
170
|
angle: anti-clockwise rotation angle in degrees
|
|
186
171
|
img_shape: the height and width of the image
|
|
187
172
|
expand: whether the image should be padded to avoid information loss
|
|
188
173
|
|
|
189
174
|
Returns:
|
|
190
|
-
-------
|
|
191
175
|
A batch of rotated polygons (N, 4, 2)
|
|
192
176
|
"""
|
|
193
177
|
# Switch to polygons
|
|
@@ -213,19 +197,17 @@ def rotate_abs_geoms(
|
|
|
213
197
|
return rotated_polys
|
|
214
198
|
|
|
215
199
|
|
|
216
|
-
def remap_boxes(loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape: Tuple[int, int]) -> np.ndarray:
|
|
200
|
+
def remap_boxes(loc_preds: np.ndarray, orig_shape: tuple[int, int], dest_shape: tuple[int, int]) -> np.ndarray:
|
|
217
201
|
"""Remaps a batch of rotated locpred (N, 4, 2) expressed for an origin_shape to a destination_shape.
|
|
218
202
|
This does not impact the absolute shape of the boxes, but allow to calculate the new relative RotatedBbox
|
|
219
203
|
coordinates after a resizing of the image.
|
|
220
204
|
|
|
221
205
|
Args:
|
|
222
|
-
----
|
|
223
206
|
loc_preds: (N, 4, 2) array of RELATIVE loc_preds
|
|
224
207
|
orig_shape: shape of the origin image
|
|
225
208
|
dest_shape: shape of the destination image
|
|
226
209
|
|
|
227
210
|
Returns:
|
|
228
|
-
-------
|
|
229
211
|
A batch of rotated loc_preds (N, 4, 2) expressed in the destination referencial
|
|
230
212
|
"""
|
|
231
213
|
if len(dest_shape) != 2:
|
|
@@ -244,9 +226,9 @@ def remap_boxes(loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape:
|
|
|
244
226
|
def rotate_boxes(
|
|
245
227
|
loc_preds: np.ndarray,
|
|
246
228
|
angle: float,
|
|
247
|
-
orig_shape: Tuple[int, int],
|
|
229
|
+
orig_shape: tuple[int, int],
|
|
248
230
|
min_angle: float = 1.0,
|
|
249
|
-
target_shape: Optional[Tuple[int, int]] = None,
|
|
231
|
+
target_shape: tuple[int, int] | None = None,
|
|
250
232
|
) -> np.ndarray:
|
|
251
233
|
"""Rotate a batch of straight bounding boxes (xmin, ymin, xmax, ymax, c) or rotated bounding boxes
|
|
252
234
|
(4, 2) of an angle, if angle > min_angle, around the center of the page.
|
|
@@ -254,7 +236,6 @@ def rotate_boxes(
|
|
|
254
236
|
is done to remove the padding that is created by rotate_page(expand=True)
|
|
255
237
|
|
|
256
238
|
Args:
|
|
257
|
-
----
|
|
258
239
|
loc_preds: (N, 4) or (N, 4, 2) array of RELATIVE boxes
|
|
259
240
|
angle: angle between -90 and +90 degrees
|
|
260
241
|
orig_shape: shape of the origin image
|
|
@@ -262,7 +243,6 @@ def rotate_boxes(
|
|
|
262
243
|
target_shape: shape of the destination image
|
|
263
244
|
|
|
264
245
|
Returns:
|
|
265
|
-
-------
|
|
266
246
|
A batch of rotated boxes (N, 4, 2): or a batch of straight bounding boxes
|
|
267
247
|
"""
|
|
268
248
|
# Change format of the boxes to rotated boxes
|
|
@@ -309,14 +289,12 @@ def rotate_image(
|
|
|
309
289
|
"""Rotate an image counterclockwise by an given angle.
|
|
310
290
|
|
|
311
291
|
Args:
|
|
312
|
-
----
|
|
313
292
|
image: numpy tensor to rotate
|
|
314
293
|
angle: rotation angle in degrees, between -90 and +90
|
|
315
294
|
expand: whether the image should be padded before the rotation
|
|
316
295
|
preserve_origin_shape: if expand is set to True, resizes the final output to the original image size
|
|
317
296
|
|
|
318
297
|
Returns:
|
|
319
|
-
-------
|
|
320
298
|
Rotated array, padded by 0 by default.
|
|
321
299
|
"""
|
|
322
300
|
# Compute the expanded padding
|
|
@@ -343,7 +321,7 @@ def rotate_image(
|
|
|
343
321
|
# Pad height
|
|
344
322
|
else:
|
|
345
323
|
h_pad, w_pad = int(rot_img.shape[1] * image.shape[0] / image.shape[1] - rot_img.shape[0]), 0
|
|
346
|
-
rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0)))
|
|
324
|
+
rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0)))
|
|
347
325
|
if preserve_origin_shape:
|
|
348
326
|
# rescale
|
|
349
327
|
rot_img = cv2.resize(rot_img, image.shape[:-1][::-1], interpolation=cv2.INTER_LINEAR)
|
|
@@ -351,6 +329,24 @@ def rotate_image(
|
|
|
351
329
|
return rot_img
|
|
352
330
|
|
|
353
331
|
|
|
332
|
+
def remove_image_padding(image: np.ndarray) -> np.ndarray:
|
|
333
|
+
"""Remove black border padding from an image
|
|
334
|
+
|
|
335
|
+
Args:
|
|
336
|
+
image: numpy tensor to remove padding from
|
|
337
|
+
|
|
338
|
+
Returns:
|
|
339
|
+
Image with padding removed
|
|
340
|
+
"""
|
|
341
|
+
# Find the bounding box of the non-black region
|
|
342
|
+
rows = np.any(image, axis=1)
|
|
343
|
+
cols = np.any(image, axis=0)
|
|
344
|
+
rmin, rmax = np.where(rows)[0][[0, -1]]
|
|
345
|
+
cmin, cmax = np.where(cols)[0][[0, -1]]
|
|
346
|
+
|
|
347
|
+
return image[rmin : rmax + 1, cmin : cmax + 1]
|
|
348
|
+
|
|
349
|
+
|
|
354
350
|
def estimate_page_angle(polys: np.ndarray) -> float:
|
|
355
351
|
"""Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the
|
|
356
352
|
estimated angle ccw in degrees
|
|
@@ -369,16 +365,14 @@ def estimate_page_angle(polys: np.ndarray) -> float:
|
|
|
369
365
|
return 0.0
|
|
370
366
|
|
|
371
367
|
|
|
372
|
-
def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray:
|
|
368
|
+
def convert_to_relative_coords(geoms: np.ndarray, img_shape: tuple[int, int]) -> np.ndarray:
|
|
373
369
|
"""Convert a geometry to relative coordinates
|
|
374
370
|
|
|
375
371
|
Args:
|
|
376
|
-
----
|
|
377
372
|
geoms: a set of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4)
|
|
378
373
|
img_shape: the height and width of the image
|
|
379
374
|
|
|
380
375
|
Returns:
|
|
381
|
-
-------
|
|
382
376
|
the updated geometry
|
|
383
377
|
"""
|
|
384
378
|
# Polygon
|
|
@@ -396,18 +390,16 @@ def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) ->
|
|
|
396
390
|
raise ValueError(f"invalid format for arg `geoms`: {geoms.shape}")
|
|
397
391
|
|
|
398
392
|
|
|
399
|
-
def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> List[np.ndarray]:
|
|
393
|
+
def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> list[np.ndarray]:
|
|
400
394
|
"""Created cropped images from list of bounding boxes
|
|
401
395
|
|
|
402
396
|
Args:
|
|
403
|
-
----
|
|
404
397
|
img: input image
|
|
405
398
|
boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative
|
|
406
399
|
coordinates (xmin, ymin, xmax, ymax)
|
|
407
400
|
channels_last: whether the channel dimensions is the last one instead of the last one
|
|
408
401
|
|
|
409
402
|
Returns:
|
|
410
|
-
-------
|
|
411
403
|
list of cropped images
|
|
412
404
|
"""
|
|
413
405
|
if boxes.shape[0] == 0:
|
|
@@ -431,19 +423,18 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True
|
|
|
431
423
|
|
|
432
424
|
|
|
433
425
|
def extract_rcrops(
|
|
434
|
-
img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True
|
|
435
|
-
) -> List[np.ndarray]:
|
|
426
|
+
img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True, assume_horizontal: bool = False
|
|
427
|
+
) -> list[np.ndarray]:
|
|
436
428
|
"""Created cropped images from list of rotated bounding boxes
|
|
437
429
|
|
|
438
430
|
Args:
|
|
439
|
-
----
|
|
440
431
|
img: input image
|
|
441
432
|
polys: bounding boxes of shape (N, 4, 2)
|
|
442
433
|
dtype: target data type of bounding boxes
|
|
443
434
|
channels_last: whether the channel dimensions is the last one instead of the last one
|
|
435
|
+
assume_horizontal: whether the boxes are assumed to be only horizontally oriented
|
|
444
436
|
|
|
445
437
|
Returns:
|
|
446
|
-
-------
|
|
447
438
|
list of cropped images
|
|
448
439
|
"""
|
|
449
440
|
if polys.shape[0] == 0:
|
|
@@ -458,22 +449,87 @@ def extract_rcrops(
|
|
|
458
449
|
_boxes[:, :, 0] *= width
|
|
459
450
|
_boxes[:, :, 1] *= height
|
|
460
451
|
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
452
|
+
src_img = img if channels_last else img.transpose(1, 2, 0)
|
|
453
|
+
|
|
454
|
+
# Handle only horizontal oriented boxes
|
|
455
|
+
if assume_horizontal:
|
|
456
|
+
crops = []
|
|
457
|
+
|
|
458
|
+
for box in _boxes:
|
|
459
|
+
# Calculate the centroid of the quadrilateral
|
|
460
|
+
centroid = np.mean(box, axis=0)
|
|
461
|
+
|
|
462
|
+
# Divide the points into left and right
|
|
463
|
+
left_points = box[box[:, 0] < centroid[0]]
|
|
464
|
+
right_points = box[box[:, 0] >= centroid[0]]
|
|
465
|
+
|
|
466
|
+
# Sort the left points according to the y-axis
|
|
467
|
+
left_points = left_points[np.argsort(left_points[:, 1])]
|
|
468
|
+
top_left_pt = left_points[0]
|
|
469
|
+
bottom_left_pt = left_points[-1]
|
|
470
|
+
# Sort the right points according to the y-axis
|
|
471
|
+
right_points = right_points[np.argsort(right_points[:, 1])]
|
|
472
|
+
top_right_pt = right_points[0]
|
|
473
|
+
bottom_right_pt = right_points[-1]
|
|
474
|
+
box_points = np.array(
|
|
475
|
+
[top_left_pt, bottom_left_pt, top_right_pt, bottom_right_pt],
|
|
476
|
+
dtype=dtype,
|
|
477
|
+
)
|
|
478
|
+
|
|
479
|
+
# Get the width and height of the rectangle that will contain the warped quadrilateral
|
|
480
|
+
width_upper = np.linalg.norm(top_right_pt - top_left_pt)
|
|
481
|
+
width_lower = np.linalg.norm(bottom_right_pt - bottom_left_pt)
|
|
482
|
+
height_left = np.linalg.norm(bottom_left_pt - top_left_pt)
|
|
483
|
+
height_right = np.linalg.norm(bottom_right_pt - top_right_pt)
|
|
484
|
+
|
|
485
|
+
# Get the maximum width and height
|
|
486
|
+
rect_width = max(int(width_upper), int(width_lower))
|
|
487
|
+
rect_height = max(int(height_left), int(height_right))
|
|
488
|
+
|
|
489
|
+
dst_pts = np.array(
|
|
490
|
+
[
|
|
491
|
+
[0, 0], # top-left
|
|
492
|
+
# bottom-left
|
|
493
|
+
[0, rect_height - 1],
|
|
494
|
+
# top-right
|
|
495
|
+
[rect_width - 1, 0],
|
|
496
|
+
# bottom-right
|
|
497
|
+
[rect_width - 1, rect_height - 1],
|
|
498
|
+
],
|
|
499
|
+
dtype=dtype,
|
|
500
|
+
)
|
|
501
|
+
|
|
502
|
+
# Get the perspective transform matrix using the box points
|
|
503
|
+
affine_mat = cv2.getPerspectiveTransform(box_points, dst_pts)
|
|
504
|
+
|
|
505
|
+
# Perform the perspective warp to get the rectified crop
|
|
506
|
+
crop = cv2.warpPerspective(
|
|
507
|
+
src_img,
|
|
508
|
+
affine_mat,
|
|
509
|
+
(rect_width, rect_height),
|
|
510
|
+
)
|
|
511
|
+
|
|
512
|
+
# Add the crop to the list of crops
|
|
513
|
+
crops.append(crop)
|
|
514
|
+
|
|
515
|
+
# Handle any oriented boxes
|
|
516
|
+
else:
|
|
517
|
+
src_pts = _boxes[:, :3].astype(np.float32)
|
|
518
|
+
# Preserve size
|
|
519
|
+
d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1)
|
|
520
|
+
d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1)
|
|
521
|
+
# (N, 3, 2)
|
|
522
|
+
dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype)
|
|
523
|
+
dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1
|
|
524
|
+
dst_pts[:, 2, 1] = d2 - 1
|
|
525
|
+
# Use a warp transformation to extract the crop
|
|
526
|
+
crops = [
|
|
527
|
+
cv2.warpAffine(
|
|
528
|
+
src_img,
|
|
529
|
+
# Transformation matrix
|
|
530
|
+
cv2.getAffineTransform(src_pts[idx], dst_pts[idx]),
|
|
531
|
+
(int(d1[idx]), int(d2[idx])),
|
|
532
|
+
)
|
|
533
|
+
for idx in range(_boxes.shape[0])
|
|
534
|
+
]
|
|
535
|
+
return crops
|
doctr/utils/metrics.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-2024, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Dict, List, Optional, Tuple
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
from anyascii import anyascii
|
|
@@ -21,16 +20,14 @@ __all__ = [
|
|
|
21
20
|
]
|
|
22
21
|
|
|
23
22
|
|
|
24
|
-
def string_match(word1: str, word2: str) -> Tuple[bool, bool, bool, bool]:
|
|
23
|
+
def string_match(word1: str, word2: str) -> tuple[bool, bool, bool, bool]:
|
|
25
24
|
"""Performs string comparison with multiple levels of tolerance
|
|
26
25
|
|
|
27
26
|
Args:
|
|
28
|
-
----
|
|
29
27
|
word1: a string
|
|
30
28
|
word2: another string
|
|
31
29
|
|
|
32
30
|
Returns:
|
|
33
|
-
-------
|
|
34
31
|
a tuple with booleans specifying respectively whether the raw strings, their lower-case counterparts, their
|
|
35
32
|
anyascii counterparts and their lower-case anyascii counterparts match
|
|
36
33
|
"""
|
|
@@ -78,13 +75,12 @@ class TextMatch:
|
|
|
78
75
|
|
|
79
76
|
def update(
|
|
80
77
|
self,
|
|
81
|
-
gt: List[str],
|
|
82
|
-
pred: List[str],
|
|
78
|
+
gt: list[str],
|
|
79
|
+
pred: list[str],
|
|
83
80
|
) -> None:
|
|
84
81
|
"""Update the state of the metric with new predictions
|
|
85
82
|
|
|
86
83
|
Args:
|
|
87
|
-
----
|
|
88
84
|
gt: list of groung-truth character sequences
|
|
89
85
|
pred: list of predicted character sequences
|
|
90
86
|
"""
|
|
@@ -100,11 +96,10 @@ class TextMatch:
|
|
|
100
96
|
|
|
101
97
|
self.total += len(gt)
|
|
102
98
|
|
|
103
|
-
def summary(self) -> Dict[str, float]:
|
|
99
|
+
def summary(self) -> dict[str, float]:
|
|
104
100
|
"""Computes the aggregated metrics
|
|
105
101
|
|
|
106
|
-
Returns
|
|
107
|
-
-------
|
|
102
|
+
Returns:
|
|
108
103
|
a dictionary with the exact match score for the raw data, its lower-case counterpart, its anyascii
|
|
109
104
|
counterpart and its lower-case anyascii counterpart
|
|
110
105
|
"""
|
|
@@ -130,12 +125,10 @@ def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray:
|
|
|
130
125
|
"""Computes the IoU between two sets of bounding boxes
|
|
131
126
|
|
|
132
127
|
Args:
|
|
133
|
-
----
|
|
134
128
|
boxes_1: bounding boxes of shape (N, 4) in format (xmin, ymin, xmax, ymax)
|
|
135
129
|
boxes_2: bounding boxes of shape (M, 4) in format (xmin, ymin, xmax, ymax)
|
|
136
130
|
|
|
137
131
|
Returns:
|
|
138
|
-
-------
|
|
139
132
|
the IoU matrix of shape (N, M)
|
|
140
133
|
"""
|
|
141
134
|
iou_mat: np.ndarray = np.zeros((boxes_1.shape[0], boxes_2.shape[0]), dtype=np.float32)
|
|
@@ -149,7 +142,7 @@ def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray:
|
|
|
149
142
|
right = np.minimum(r1, r2.T)
|
|
150
143
|
bot = np.minimum(b1, b2.T)
|
|
151
144
|
|
|
152
|
-
intersection = np.clip(right - left, 0, np.Inf) * np.clip(bot - top, 0, np.Inf)
|
|
145
|
+
intersection = np.clip(right - left, 0, np.inf) * np.clip(bot - top, 0, np.inf)
|
|
153
146
|
union = (r1 - l1) * (b1 - t1) + ((r2 - l2) * (b2 - t2)).T - intersection
|
|
154
147
|
iou_mat = intersection / union
|
|
155
148
|
|
|
@@ -160,14 +153,12 @@ def polygon_iou(polys_1: np.ndarray, polys_2: np.ndarray) -> np.ndarray:
|
|
|
160
153
|
"""Computes the IoU between two sets of rotated bounding boxes
|
|
161
154
|
|
|
162
155
|
Args:
|
|
163
|
-
----
|
|
164
156
|
polys_1: rotated bounding boxes of shape (N, 4, 2)
|
|
165
157
|
polys_2: rotated bounding boxes of shape (M, 4, 2)
|
|
166
158
|
mask_shape: spatial shape of the intermediate masks
|
|
167
159
|
use_broadcasting: if set to True, leverage broadcasting speedup by consuming more memory
|
|
168
160
|
|
|
169
161
|
Returns:
|
|
170
|
-
-------
|
|
171
162
|
the IoU matrix of shape (N, M)
|
|
172
163
|
"""
|
|
173
164
|
if polys_1.ndim != 3 or polys_2.ndim != 3:
|
|
@@ -187,16 +178,14 @@ def polygon_iou(polys_1: np.ndarray, polys_2: np.ndarray) -> np.ndarray:
|
|
|
187
178
|
return iou_mat
|
|
188
179
|
|
|
189
180
|
|
|
190
|
-
def nms(boxes: np.ndarray, thresh: float = 0.5) -> List[int]:
|
|
181
|
+
def nms(boxes: np.ndarray, thresh: float = 0.5) -> list[int]:
|
|
191
182
|
"""Perform non-max suppression, borrowed from <https://github.com/rbgirshick/fast-rcnn>`_.
|
|
192
183
|
|
|
193
184
|
Args:
|
|
194
|
-
----
|
|
195
185
|
boxes: np array of straight boxes: (*, 5), (xmin, ymin, xmax, ymax, score)
|
|
196
186
|
thresh: iou threshold to perform box suppression.
|
|
197
187
|
|
|
198
188
|
Returns:
|
|
199
|
-
-------
|
|
200
189
|
A list of box indexes to keep
|
|
201
190
|
"""
|
|
202
191
|
x1 = boxes[:, 0]
|
|
@@ -260,7 +249,6 @@ class LocalizationConfusion:
|
|
|
260
249
|
>>> metric.summary()
|
|
261
250
|
|
|
262
251
|
Args:
|
|
263
|
-
----
|
|
264
252
|
iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match
|
|
265
253
|
use_polygons: if set to True, predictions and targets will be expected to have rotated format
|
|
266
254
|
"""
|
|
@@ -278,7 +266,6 @@ class LocalizationConfusion:
|
|
|
278
266
|
"""Updates the metric
|
|
279
267
|
|
|
280
268
|
Args:
|
|
281
|
-
----
|
|
282
269
|
gts: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones
|
|
283
270
|
preds: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones
|
|
284
271
|
"""
|
|
@@ -298,11 +285,10 @@ class LocalizationConfusion:
|
|
|
298
285
|
self.num_gts += gts.shape[0]
|
|
299
286
|
self.num_preds += preds.shape[0]
|
|
300
287
|
|
|
301
|
-
def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]:
|
|
288
|
+
def summary(self) -> tuple[float | None, float | None, float | None]:
|
|
302
289
|
"""Computes the aggregated metrics
|
|
303
290
|
|
|
304
|
-
Returns
|
|
305
|
-
-------
|
|
291
|
+
Returns:
|
|
306
292
|
a tuple with the recall, precision and meanIoU scores
|
|
307
293
|
"""
|
|
308
294
|
# Recall
|
|
@@ -360,7 +346,6 @@ class OCRMetric:
|
|
|
360
346
|
>>> metric.summary()
|
|
361
347
|
|
|
362
348
|
Args:
|
|
363
|
-
----
|
|
364
349
|
iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match
|
|
365
350
|
use_polygons: if set to True, predictions and targets will be expected to have rotated format
|
|
366
351
|
"""
|
|
@@ -378,13 +363,12 @@ class OCRMetric:
|
|
|
378
363
|
self,
|
|
379
364
|
gt_boxes: np.ndarray,
|
|
380
365
|
pred_boxes: np.ndarray,
|
|
381
|
-
gt_labels: List[str],
|
|
382
|
-
pred_labels: List[str],
|
|
366
|
+
gt_labels: list[str],
|
|
367
|
+
pred_labels: list[str],
|
|
383
368
|
) -> None:
|
|
384
369
|
"""Updates the metric
|
|
385
370
|
|
|
386
371
|
Args:
|
|
387
|
-
----
|
|
388
372
|
gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones
|
|
389
373
|
pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones
|
|
390
374
|
gt_labels: a list of N string labels
|
|
@@ -392,7 +376,7 @@ class OCRMetric:
|
|
|
392
376
|
"""
|
|
393
377
|
if gt_boxes.shape[0] != len(gt_labels) or pred_boxes.shape[0] != len(pred_labels):
|
|
394
378
|
raise AssertionError(
|
|
395
|
-
"there should be the same number of boxes and string both for the ground truth
|
|
379
|
+
"there should be the same number of boxes and string both for the ground truth and the predictions"
|
|
396
380
|
)
|
|
397
381
|
|
|
398
382
|
# Compute IoU
|
|
@@ -418,11 +402,10 @@ class OCRMetric:
|
|
|
418
402
|
self.num_gts += gt_boxes.shape[0]
|
|
419
403
|
self.num_preds += pred_boxes.shape[0]
|
|
420
404
|
|
|
421
|
-
def summary(self) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]], Optional[float]]:
|
|
405
|
+
def summary(self) -> tuple[dict[str, float | None], dict[str, float | None], float | None]:
|
|
422
406
|
"""Computes the aggregated metrics
|
|
423
407
|
|
|
424
|
-
Returns
|
|
425
|
-
-------
|
|
408
|
+
Returns:
|
|
426
409
|
a tuple with the recall & precision for each string comparison and the mean IoU
|
|
427
410
|
"""
|
|
428
411
|
# Recall
|
|
@@ -493,7 +476,6 @@ class DetectionMetric:
|
|
|
493
476
|
>>> metric.summary()
|
|
494
477
|
|
|
495
478
|
Args:
|
|
496
|
-
----
|
|
497
479
|
iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match
|
|
498
480
|
use_polygons: if set to True, predictions and targets will be expected to have rotated format
|
|
499
481
|
"""
|
|
@@ -517,7 +499,6 @@ class DetectionMetric:
|
|
|
517
499
|
"""Updates the metric
|
|
518
500
|
|
|
519
501
|
Args:
|
|
520
|
-
----
|
|
521
502
|
gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones
|
|
522
503
|
pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones
|
|
523
504
|
gt_labels: an array of class indices of shape (N,)
|
|
@@ -525,7 +506,7 @@ class DetectionMetric:
|
|
|
525
506
|
"""
|
|
526
507
|
if gt_boxes.shape[0] != gt_labels.shape[0] or pred_boxes.shape[0] != pred_labels.shape[0]:
|
|
527
508
|
raise AssertionError(
|
|
528
|
-
"there should be the same number of boxes and string both for the ground truth
|
|
509
|
+
"there should be the same number of boxes and string both for the ground truth and the predictions"
|
|
529
510
|
)
|
|
530
511
|
|
|
531
512
|
# Compute IoU
|
|
@@ -546,11 +527,10 @@ class DetectionMetric:
|
|
|
546
527
|
self.num_gts += gt_boxes.shape[0]
|
|
547
528
|
self.num_preds += pred_boxes.shape[0]
|
|
548
529
|
|
|
549
|
-
def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]:
|
|
530
|
+
def summary(self) -> tuple[float | None, float | None, float | None]:
|
|
550
531
|
"""Computes the aggregated metrics
|
|
551
532
|
|
|
552
|
-
Returns
|
|
553
|
-
-------
|
|
533
|
+
Returns:
|
|
554
534
|
a tuple with the recall & precision for each class prediction and the mean IoU
|
|
555
535
|
"""
|
|
556
536
|
# Recall
|