python-doctr 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/cord.py +10 -1
- doctr/datasets/funsd.py +11 -1
- doctr/datasets/ic03.py +11 -1
- doctr/datasets/ic13.py +10 -1
- doctr/datasets/iiit5k.py +26 -16
- doctr/datasets/imgur5k.py +10 -1
- doctr/datasets/sroie.py +11 -1
- doctr/datasets/svhn.py +11 -1
- doctr/datasets/svt.py +11 -1
- doctr/datasets/synthtext.py +11 -1
- doctr/datasets/utils.py +7 -2
- doctr/datasets/vocabs.py +6 -2
- doctr/datasets/wildreceipt.py +12 -1
- doctr/file_utils.py +19 -0
- doctr/io/elements.py +12 -4
- doctr/models/builder.py +2 -2
- doctr/models/classification/magc_resnet/tensorflow.py +13 -6
- doctr/models/classification/mobilenet/pytorch.py +2 -0
- doctr/models/classification/mobilenet/tensorflow.py +14 -8
- doctr/models/classification/predictor/pytorch.py +11 -7
- doctr/models/classification/predictor/tensorflow.py +10 -6
- doctr/models/classification/resnet/tensorflow.py +21 -8
- doctr/models/classification/textnet/tensorflow.py +11 -5
- doctr/models/classification/vgg/tensorflow.py +9 -3
- doctr/models/classification/vit/tensorflow.py +10 -4
- doctr/models/classification/zoo.py +22 -10
- doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
- doctr/models/detection/fast/tensorflow.py +14 -11
- doctr/models/detection/linknet/tensorflow.py +23 -11
- doctr/models/detection/predictor/tensorflow.py +2 -2
- doctr/models/factory/hub.py +5 -6
- doctr/models/kie_predictor/base.py +4 -0
- doctr/models/kie_predictor/pytorch.py +4 -0
- doctr/models/kie_predictor/tensorflow.py +8 -1
- doctr/models/modules/transformer/tensorflow.py +0 -2
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/modules/vision_transformer/tensorflow.py +1 -1
- doctr/models/predictor/base.py +24 -12
- doctr/models/predictor/pytorch.py +4 -0
- doctr/models/predictor/tensorflow.py +8 -1
- doctr/models/preprocessor/tensorflow.py +1 -1
- doctr/models/recognition/crnn/tensorflow.py +8 -6
- doctr/models/recognition/master/tensorflow.py +9 -4
- doctr/models/recognition/parseq/tensorflow.py +10 -8
- doctr/models/recognition/sar/tensorflow.py +7 -3
- doctr/models/recognition/vitstr/tensorflow.py +9 -4
- doctr/models/utils/pytorch.py +1 -1
- doctr/models/utils/tensorflow.py +15 -15
- doctr/transforms/functional/pytorch.py +1 -1
- doctr/transforms/modules/pytorch.py +7 -6
- doctr/transforms/modules/tensorflow.py +15 -12
- doctr/utils/geometry.py +106 -19
- doctr/utils/metrics.py +1 -1
- doctr/utils/reconstitution.py +151 -65
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/METADATA +11 -11
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/RECORD +61 -61
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/vitstr/tensorflow.py
CHANGED

@@ -12,7 +12,7 @@ from tensorflow.keras import Model, layers
 from doctr.datasets import VOCABS

 from ...classification import vit_b, vit_s
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor

 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

@@ -23,14 +23,14 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
     },
     "vitstr_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0",
     },
 }


@@ -216,9 +216,14 @@ def _vitstr(

     # Build the model
     model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )

     return model
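
As a rough usage sketch of what this factory change enables (the vocab string below is made up; `vitstr_small` and the `vocab` keyword are existing doctr names):

# Sketch only: fine-tuning a TF recognition model with a custom vocab.
from doctr.models import vitstr_small

# A vocab that differs from the pretrained French vocab now makes the factory call
# load_pretrained_params(..., skip_mismatch=True), so only the mismatching head is
# left randomly initialized instead of the whole weight load failing.
model = vitstr_small(pretrained=True, vocab="0123456789")  # illustrative vocab
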
doctr/models/utils/pytorch.py
CHANGED

@@ -157,7 +157,7 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     """
     torch.onnx.export(
         model,
-        dummy_input,
+        dummy_input,  # type: ignore[arg-type]
         f"{model_name}.onnx",
         input_names=["input"],
         output_names=["logits"],
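
For reference, a minimal sketch of how this helper is typically driven (the model choice, the `exportable` flag, and the dummy-input shape are assumptions, not taken from the diff):

import torch
from doctr.models import vitstr_small
from doctr.models.utils import export_model_to_onnx

model = vitstr_small(pretrained=True, exportable=True).eval()
dummy_input = torch.rand((1, 3, 32, 128), dtype=torch.float32)  # (N, C, H, W) for a 32x128 crop
model_path = export_model_to_onnx(model, model_name="vitstr_small", dummy_input=dummy_input)
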
doctr/models/utils/tensorflow.py
CHANGED

@@ -4,9 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

 import logging
-import os
 from typing import Any, Callable, List, Optional, Tuple, Union
-from zipfile import ZipFile

 import tensorflow as tf
 import tf2onnx

@@ -19,6 +17,7 @@ logging.getLogger("tensorflow").setLevel(logging.DEBUG)

 __all__ = [
     "load_pretrained_params",
+    "_build_model",
     "conv_sequence",
     "IntermediateLayerGetter",
     "export_model_to_onnx",

@@ -36,41 +35,42 @@ def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
     return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x


+def _build_model(model: Model):
+    """Build a model by calling it once with dummy input
+
+    Args:
+    ----
+        model: the model to be built
+    """
+    model(tf.zeros((1, *model.cfg["input_shape"])), training=False)
+
+
 def load_pretrained_params(
     model: Model,
     url: Optional[str] = None,
     hash_prefix: Optional[str] = None,
-    overwrite: bool = False,
-    internal_name: str = "weights",
+    skip_mismatch: bool = False,
     **kwargs: Any,
 ) -> None:
     """Load a set of parameters onto a model

     >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
+    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5")

     Args:
     ----
         model: the keras model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
-        overwrite: should the zip extraction be enforced if the archive has already been extracted
-        internal_name: name of the ckpt files
+        skip_mismatch: skip loading layers with mismatched shapes
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
     if url is None:
         logging.warning("Invalid model URL, using default initialization.")
     else:
         archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
-
-        # Unzip the archive
-        params_path = archive_path.parent.joinpath(archive_path.stem)
-        if not params_path.is_dir() or overwrite:
-            with ZipFile(archive_path, "r") as f:
-                f.extractall(path=params_path)
-
         # Load weights
-        model.load_weights(params_path / internal_name)
+        model.load_weights(archive_path, skip_mismatch=skip_mismatch)


 def conv_sequence(
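
A minimal sketch of the new loading flow (TensorFlow backend assumed; the checkpoint URL below is a placeholder, not a published doctr artifact):

from doctr.models import vitstr_small
from doctr.models.utils import _build_model, load_pretrained_params

model = vitstr_small(pretrained=False)
_build_model(model)  # Keras needs a built model before a .weights.h5 file can be loaded
load_pretrained_params(
    model,
    url="https://example.com/my-checkpoint-0123abcd.weights.h5",  # placeholder URL
    skip_mismatch=True,  # e.g. when the classification head has a different vocab size
)
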
doctr/transforms/functional/pytorch.py
CHANGED

@@ -89,7 +89,7 @@ def rotate_sample(
     rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[2]
     rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[1]

-    return rotated_img, np.clip(rotated_geoms, 0, 1)
+    return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)


 def crop_detection(
doctr/transforms/modules/pytorch.py
CHANGED

@@ -74,16 +74,18 @@ class Resize(T.Resize):
             if self.symmetric_pad:
                 half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2))
                 _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1])
+            # Pad image
             img = pad(img, _pad)

         # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
         if target is not None:
+            if self.symmetric_pad:
+                offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
+
             if self.preserve_aspect_ratio:
                 # Get absolute coords
                 if target.shape[1:] == (4,):
                     if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
                         target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
                         target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
                     else:

@@ -91,16 +93,15 @@ class Resize(T.Resize):
                         target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
                 elif target.shape[1:] == (4, 2):
                     if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
                         target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
                         target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
                     else:
                         target[..., 0] *= raw_shape[-1] / img.shape[-1]
                         target[..., 1] *= raw_shape[-2] / img.shape[-2]
                 else:
-                    raise AssertionError
-            return img, np.clip(target, 0, 1)
+                    raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")
+
+            return img, np.clip(target, 0, 1)

         return img

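A sketch of how the updated transform behaves (sizes and box values are illustrative; boxes are expected in relative coordinates since the `np.max(target) <= 1` guard is gone):

import numpy as np
import torch
from doctr.transforms import Resize

transform = Resize((1024, 1024), preserve_aspect_ratio=True, symmetric_pad=True)
img = torch.rand(3, 512, 1024)  # (C, H, W), wider than tall so vertical padding is added
boxes = np.array([[0.1, 0.2, 0.5, 0.6]], dtype=np.float32)  # relative (xmin, ymin, xmax, ymax)

out_img, out_boxes = transform(img, boxes)  # boxes shifted by the symmetric-pad offset, clipped to [0, 1]
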
doctr/transforms/modules/tensorflow.py
CHANGED

@@ -107,29 +107,34 @@ class Resize(NestedObject):
         target: Optional[np.ndarray] = None,
     ) -> Union[tf.Tensor, Tuple[tf.Tensor, np.ndarray]]:
         input_dtype = img.dtype
+        self.output_size = (
+            (self.output_size, self.output_size) if isinstance(self.output_size, int) else self.output_size
+        )

         img = tf.image.resize(img, self.wanted_size, self.method, self.preserve_aspect_ratio, self.antialias)
         # It will produce an un-padded resized image, with a side shorter than wanted if we preserve aspect ratio
         raw_shape = img.shape[:2]
+        if self.symmetric_pad:
+            half_pad = (int((self.output_size[0] - img.shape[0]) / 2), 0)
         if self.preserve_aspect_ratio:
             if isinstance(self.output_size, (tuple, list)):
                 # In that case we need to pad because we want to enforce both width and height
                 if not self.symmetric_pad:
-                    offset = (0, 0)
+                    half_pad = (0, 0)
                 elif self.output_size[0] == img.shape[0]:
-                    offset = (0, int((self.output_size[1] - img.shape[1]) / 2))
-                else:
-                    offset = (int((self.output_size[0] - img.shape[0]) / 2), 0)
-                img = tf.image.pad_to_bounding_box(img, *offset, *self.output_size)
+                    half_pad = (0, int((self.output_size[1] - img.shape[1]) / 2))
+                # Pad image
+                img = tf.image.pad_to_bounding_box(img, *half_pad, *self.output_size)

         # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
         if target is not None:
+            if self.symmetric_pad:
+                offset = half_pad[0] / img.shape[0], half_pad[1] / img.shape[1]
+
             if self.preserve_aspect_ratio:
                 # Get absolute coords
                 if target.shape[1:] == (4,):
                     if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = offset[0] / img.shape[0], offset[1] / img.shape[1]
                         target[:, [0, 2]] = offset[1] + target[:, [0, 2]] * raw_shape[1] / img.shape[1]
                         target[:, [1, 3]] = offset[0] + target[:, [1, 3]] * raw_shape[0] / img.shape[0]
                     else:

@@ -137,16 +142,15 @@ class Resize(NestedObject):
                         target[:, [1, 3]] *= raw_shape[0] / img.shape[0]
                 elif target.shape[1:] == (4, 2):
                     if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = offset[0] / img.shape[0], offset[1] / img.shape[1]
                         target[..., 0] = offset[1] + target[..., 0] * raw_shape[1] / img.shape[1]
                         target[..., 1] = offset[0] + target[..., 1] * raw_shape[0] / img.shape[0]
                     else:
                         target[..., 0] *= raw_shape[1] / img.shape[1]
                         target[..., 1] *= raw_shape[0] / img.shape[0]
                 else:
-                    raise AssertionError
-            return tf.cast(img, dtype=input_dtype), np.clip(target, 0, 1)
+                    raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")
+
+            return tf.cast(img, dtype=input_dtype), np.clip(target, 0, 1)

         return tf.cast(img, dtype=input_dtype)


@@ -395,7 +399,6 @@ class GaussianBlur(NestedObject):
     def extra_repr(self) -> str:
         return f"kernel_shape={self.kernel_shape}, std={self.std}"

-    @tf.function
     def __call__(self, img: tf.Tensor) -> tf.Tensor:
         return tf.squeeze(
             _gaussian_filter(
doctr/utils/geometry.py
CHANGED

@@ -20,6 +20,7 @@ __all__ = [
     "rotate_boxes",
     "compute_expanded_shape",
     "rotate_image",
+    "remove_image_padding",
     "estimate_page_angle",
     "convert_to_relative_coords",
     "rotate_abs_geoms",

@@ -351,6 +352,26 @@ def rotate_image(
     return rot_img


+def remove_image_padding(image: np.ndarray) -> np.ndarray:
+    """Remove black border padding from an image
+
+    Args:
+    ----
+        image: numpy tensor to remove padding from
+
+    Returns:
+    -------
+        Image with padding removed
+    """
+    # Find the bounding box of the non-black region
+    rows = np.any(image, axis=1)
+    cols = np.any(image, axis=0)
+    rmin, rmax = np.where(rows)[0][[0, -1]]
+    cmin, cmax = np.where(cols)[0][[0, -1]]
+
+    return image[rmin : rmax + 1, cmin : cmax + 1]
+
+
 def estimate_page_angle(polys: np.ndarray) -> float:
     """Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the
     estimated angle ccw in degrees

@@ -431,7 +452,7 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True


 def extract_rcrops(
-    img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True
+    img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True, assume_horizontal: bool = False
 ) -> List[np.ndarray]:
     """Created cropped images from list of rotated bounding boxes


@@ -441,6 +462,7 @@ def extract_rcrops(
         polys: bounding boxes of shape (N, 4, 2)
         dtype: target data type of bounding boxes
         channels_last: whether the channel dimensions is the last one instead of the last one
+        assume_horizontal: whether the boxes are assumed to be only horizontally oriented

     Returns:
     -------

@@ -458,22 +480,87 @@ def extract_rcrops(
         _boxes[:, :, 0] *= width
         _boxes[:, :, 1] *= height

-    src_pts = _boxes[:, :3].astype(np.float32)
-    # Preserve size
-    d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1)
-    d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1)
-    # (N, 3, 2)
-    dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype)
-    dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1
-    dst_pts[:, 2, 1] = d2 - 1
-    # Use a warp transformation to extract the crop
-    crops = [
-        cv2.warpAffine(
-            img if channels_last else img.transpose(1, 2, 0),
-            # Transformation matrix
-            cv2.getAffineTransform(src_pts[idx], dst_pts[idx]),
-            (int(d1[idx]), int(d2[idx])),
-        )
-        for idx in range(_boxes.shape[0])
-    ]
+    src_img = img if channels_last else img.transpose(1, 2, 0)
+
+    # Handle only horizontal oriented boxes
+    if assume_horizontal:
+        crops = []
+
+        for box in _boxes:
+            # Calculate the centroid of the quadrilateral
+            centroid = np.mean(box, axis=0)
+
+            # Divide the points into left and right
+            left_points = box[box[:, 0] < centroid[0]]
+            right_points = box[box[:, 0] >= centroid[0]]
+
+            # Sort the left points according to the y-axis
+            left_points = left_points[np.argsort(left_points[:, 1])]
+            top_left_pt = left_points[0]
+            bottom_left_pt = left_points[-1]
+            # Sort the right points according to the y-axis
+            right_points = right_points[np.argsort(right_points[:, 1])]
+            top_right_pt = right_points[0]
+            bottom_right_pt = right_points[-1]
+            box_points = np.array(
+                [top_left_pt, bottom_left_pt, top_right_pt, bottom_right_pt],
+                dtype=dtype,
+            )
+
+            # Get the width and height of the rectangle that will contain the warped quadrilateral
+            width_upper = np.linalg.norm(top_right_pt - top_left_pt)
+            width_lower = np.linalg.norm(bottom_right_pt - bottom_left_pt)
+            height_left = np.linalg.norm(bottom_left_pt - top_left_pt)
+            height_right = np.linalg.norm(bottom_right_pt - top_right_pt)
+
+            # Get the maximum width and height
+            rect_width = max(int(width_upper), int(width_lower))
+            rect_height = max(int(height_left), int(height_right))
+
+            dst_pts = np.array(
+                [
+                    [0, 0],  # top-left
+                    # bottom-left
+                    [0, rect_height - 1],
+                    # top-right
+                    [rect_width - 1, 0],
+                    # bottom-right
+                    [rect_width - 1, rect_height - 1],
+                ],
+                dtype=dtype,
+            )
+
+            # Get the perspective transform matrix using the box points
+            affine_mat = cv2.getPerspectiveTransform(box_points, dst_pts)
+
+            # Perform the perspective warp to get the rectified crop
+            crop = cv2.warpPerspective(
+                src_img,
+                affine_mat,
+                (rect_width, rect_height),
+            )
+
+            # Add the crop to the list of crops
+            crops.append(crop)
+
+    # Handle any oriented boxes
+    else:
+        src_pts = _boxes[:, :3].astype(np.float32)
+        # Preserve size
+        d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1)
+        d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1)
+        # (N, 3, 2)
+        dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype)
+        dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1
+        dst_pts[:, 2, 1] = d2 - 1
+        # Use a warp transformation to extract the crop
+        crops = [
+            cv2.warpAffine(
+                src_img,
+                # Transformation matrix
+                cv2.getAffineTransform(src_pts[idx], dst_pts[idx]),
+                (int(d1[idx]), int(d2[idx])),
+            )
+            for idx in range(_boxes.shape[0])
+        ]
     return crops  # type: ignore[return-value]
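
A small sketch exercising the two additions above (array values are made up):

import numpy as np
from doctr.utils.geometry import extract_rcrops, remove_image_padding

page = np.zeros((256, 256, 3), dtype=np.uint8)
page[32:224, 16:240] = 255  # content surrounded by black padding
trimmed = remove_image_padding(page)  # -> shape (192, 224, 3)

# One roughly horizontal quadrilateral in relative (N, 4, 2) coordinates
polys = np.array([[[0.1, 0.2], [0.6, 0.22], [0.6, 0.32], [0.1, 0.3]]], dtype=np.float32)
crops = extract_rcrops(trimmed, polys, assume_horizontal=True)  # takes the new perspective-warp branch
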
doctr/utils/metrics.py
CHANGED

@@ -149,7 +149,7 @@ def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray:
         right = np.minimum(r1, r2.T)
         bot = np.minimum(b1, b2.T)

-        intersection = np.clip(right - left, 0, np.Inf) * np.clip(bot - top, 0, np.Inf)
+        intersection = np.clip(right - left, 0, np.inf) * np.clip(bot - top, 0, np.inf)
         union = (r1 - l1) * (b1 - t1) + ((r2 - l2) * (b2 - t2)).T - intersection
         iou_mat = intersection / union

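A quick sketch of the function being patched (box values are illustrative):

import numpy as np
from doctr.utils.metrics import box_iou

a = np.array([[0, 0, 10, 10]], dtype=np.float32)  # (xmin, ymin, xmax, ymax)
b = np.array([[5, 5, 15, 15]], dtype=np.float32)
iou = box_iou(a, b)  # 25 / 175 ≈ 0.143, now computed with np.inf since NumPy 2.x drops np.Inf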