python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/datasets/sroie.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 import csv
 import os
 from pathlib import Path
-from typing import Any
+from typing import Any

 import numpy as np
 from tqdm import tqdm
@@ -29,7 +29,6 @@ class SROIE(VisionDataset):
     >>> img, target = train_set[0]

     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
@@ -74,10 +73,12 @@ class SROIE(VisionDataset):
         self.train = train

         tmp_root = os.path.join(self.root, "images")
-        self.data:
+        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
         np_dtype = np.float32

-        for img_path in tqdm(
+        for img_path in tqdm(
+            iterable=os.listdir(tmp_root), desc="Preparing and Loading SROIE", total=len(os.listdir(tmp_root))
+        ):
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_path)):
                 raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
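For context, the new loading loop simply wraps the directory listing in a tqdm progress bar with an explicit description and total. A minimal sketch of that pattern, assuming a hypothetical image directory (only `os` and `tqdm` are real dependencies here):

import os

from tqdm import tqdm

image_dir = "/path/to/SROIE/images"  # hypothetical location of the extracted archive
file_names = os.listdir(image_dir)

for file_name in tqdm(iterable=file_names, desc="Preparing and Loading SROIE", total=len(file_names)):
    file_path = os.path.join(image_dir, file_name)
    # same existence check the loader performs before any parsing happens
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"unable to locate {file_path}")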
doctr/datasets/svhn.py
CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

 import os
-from typing import Any
+from typing import Any

 import h5py
 import numpy as np
@@ -28,7 +28,6 @@ class SVHN(VisionDataset):
     >>> img, target = train_set[0]

     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
@@ -72,7 +71,7 @@ class SVHN(VisionDataset):
         )

         self.train = train
-        self.data:
+        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
         np_dtype = np.float32

         tmp_root = os.path.join(self.root, "train" if train else "test")
@@ -81,7 +80,9 @@ class SVHN(VisionDataset):
         with h5py.File(os.path.join(tmp_root, "digitStruct.mat"), "r") as f:
             img_refs = f["digitStruct/name"]
             box_refs = f["digitStruct/bbox"]
-            for img_ref, box_ref in tqdm(
+            for img_ref, box_ref in tqdm(
+                iterable=zip(img_refs, box_refs), desc="Preparing and Loading SVHN", total=len(img_refs)
+            ):
                 # convert ascii matrix to string
                 img_name = "".join(map(chr, f[img_ref[0]][()].flatten()))

@@ -128,7 +129,7 @@ class SVHN(VisionDataset):
                 if recognition_task:
                     crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_name), geoms=box_targets)
                     for crop, label in zip(crops, label_targets):
-                        if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+                        if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0 and " " not in label:
                             self.data.append((crop, label))
                 elif detection_task:
                     self.data.append((img_name, box_targets))
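The behavioural change in the last hunk is the extra `" " not in label` clause: recognition crops whose transcription contains a space are now skipped (the SVT and SynthText loaders below apply the same filter). A small self-contained illustration, with toy stand-ins for the crops and labels, of which pairs survive:

import numpy as np

# toy stand-ins for the cropped images and their transcriptions
crops = [np.zeros((32, 100, 3), dtype=np.uint8), np.zeros((0, 50, 3), dtype=np.uint8)]
labels = ["123", "12 3"]

kept = [
    (crop, label)
    for crop, label in zip(crops, labels)
    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0 and " " not in label
]
print(len(kept))  # 1 -- the empty crop and the space-containing label are both dropped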
doctr/datasets/svt.py
CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

 import os
-from typing import Any
+from typing import Any

 import defusedxml.ElementTree as ET
 import numpy as np
@@ -28,7 +28,6 @@ class SVT(VisionDataset):
     >>> img, target = train_set[0]

     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
@@ -36,7 +35,7 @@ class SVT(VisionDataset):
         **kwargs: keyword arguments from `VisionDataset`.
     """

-    URL = "http://
+    URL = "http://www.iapr-tc11.org/dataset/SVT/svt.zip"
     SHA256 = "63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf"

     def __init__(
@@ -62,7 +61,7 @@ class SVT(VisionDataset):
         )

         self.train = train
-        self.data:
+        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
         np_dtype = np.float32

         # Load xml data
@@ -74,7 +73,7 @@ class SVT(VisionDataset):
         )
         xml_root = xml_tree.getroot()

-        for image in tqdm(iterable=xml_root, desc="
+        for image in tqdm(iterable=xml_root, desc="Preparing and Loading SVT", total=len(xml_root)):
             name, _, _, _resolution, rectangles = image

             # File existence check
@@ -114,7 +113,7 @@ class SVT(VisionDataset):
            if recognition_task:
                crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
                for crop, label in zip(crops, labels):
-                   if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+                   if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0 and " " not in label:
                        self.data.append((crop, label))
            elif detection_task:
                self.data.append((name.text, boxes))
doctr/datasets/synthtext.py
CHANGED

@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

 import glob
 import os
-from typing import Any
+from typing import Any

 import numpy as np
 from PIL import Image
@@ -31,7 +31,6 @@ class SynthText(VisionDataset):
     >>> img, target = train_set[0]

     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
@@ -42,6 +41,12 @@ class SynthText(VisionDataset):
     URL = "https://thor.robots.ox.ac.uk/~vgg/data/scenetext/SynthText.zip"
     SHA256 = "28ab030485ec8df3ed612c568dd71fb2793b9afbfa3a9d9c6e792aef33265bf1"

+    # filter corrupted or missing images
+    BLACKLIST = (
+        "67/fruits_129_",
+        "194/window_19_",
+    )
+
     def __init__(
         self,
         train: bool = True,
@@ -65,7 +70,7 @@ class SynthText(VisionDataset):
         )

         self.train = train
-        self.data:
+        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
         np_dtype = np.float32

         # Load mat data
@@ -91,7 +96,7 @@ class SynthText(VisionDataset):
         del mat_data

         for img_path, word_boxes, txt in tqdm(
-            iterable=zip(paths, boxes, labels), desc="
+            iterable=zip(paths, boxes, labels), desc="Preparing and Loading SynthText", total=len(paths)
         ):
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_path[0])):
@@ -112,7 +117,13 @@ class SynthText(VisionDataset):
             if recognition_task:
                 crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path[0]), geoms=word_boxes)
                 for crop, label in zip(crops, labels):
-                    if
+                    if (
+                        crop.shape[0] > 0
+                        and crop.shape[1] > 0
+                        and len(label) > 0
+                        and len(label) < 30
+                        and " " not in label
+                    ):
                         # write data to disk
                         with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
                             f.write(label)
@@ -133,6 +144,7 @@ class SynthText(VisionDataset):
         return f"train={self.train}"

     def _read_from_folder(self, path: str) -> None:
-
+        img_paths = glob.glob(os.path.join(path, "*.png"))
+        for img_path in tqdm(iterable=img_paths, desc="Preparing and Loading SynthText", total=len(img_paths)):
             with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
                 self.data.append((img_path, f.read()))
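SynthText additionally gains a `BLACKLIST` of known corrupted or missing samples and a tighter label filter (non-empty, shorter than 30 characters, no spaces). The diff does not show where the blacklist is consulted, so the sketch below only assumes each entry is a relative-path prefix to skip:

BLACKLIST = (
    "67/fruits_129_",
    "194/window_19_",
)

def is_blacklisted(img_path: str) -> bool:
    # assumption: a sample is skipped when its relative path starts with a blacklisted prefix
    return any(img_path.startswith(prefix) for prefix in BLACKLIST)

print(is_blacklisted("67/fruits_129_83.jpg"))  # True
print(is_blacklisted("8/ballet_106_0.jpg"))    # False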
doctr/datasets/utils.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,10 +6,10 @@
 import string
 import unicodedata
 from collections.abc import Sequence
+from collections.abc import Sequence as SequenceType
 from functools import partial
 from pathlib import Path
-from typing import Any,
-from typing import Sequence as SequenceType
+from typing import Any, TypeVar

 import numpy as np
 from PIL import Image
@@ -19,7 +19,15 @@ from doctr.utils.geometry import convert_to_relative_coords, extract_crops, extr

 from .vocabs import VOCABS

-__all__ = [
+__all__ = [
+    "translate",
+    "encode_string",
+    "decode_sequence",
+    "encode_sequences",
+    "pre_transform_multiclass",
+    "crop_bboxes_from_image",
+    "convert_target_to_relative",
+]

 ImageTensor = TypeVar("ImageTensor")

@@ -32,17 +40,15 @@ def translate(
     """Translate a string input in a given vocabulary

     Args:
-    ----
         input_string: input string to translate
         vocab_name: vocabulary to use (french, latin, ...)
         unknown_char: unknown character for non-translatable characters

     Returns:
-    -------
         A string translated in a given vocab
     """
     if VOCABS.get(vocab_name) is None:
-        raise KeyError("output vocabulary must be in vocabs
+        raise KeyError("output vocabulary must be in vocabs dictionary")

     translated = ""
     for char in input_string:
@@ -63,40 +69,37 @@ def translate(
 def encode_string(
     input_string: str,
     vocab: str,
-) ->
+) -> list[int]:
     """Given a predefined mapping, encode the string to a sequence of numbers

     Args:
-    ----
         input_string: string to encode
         vocab: vocabulary (string), the encoding is given by the indexing of the character sequence

     Returns:
-    -------
         A list encoding the input_string
     """
     try:
         return list(map(vocab.index, input_string))
-    except ValueError:
+    except ValueError as e:
+        missing_chars = [char for char in input_string if char not in vocab]
         raise ValueError(
-            f"
-
-        )
+            f"Some characters cannot be found in 'vocab': {set(missing_chars)}.\n"
+            f"Please check the input string `{input_string}` and the vocabulary `{vocab}`"
+        ) from e


 def decode_sequence(
-    input_seq:
+    input_seq: np.ndarray | SequenceType[int],
     mapping: str,
 ) -> str:
     """Given a predefined mapping, decode the sequence of numbers to a string

     Args:
-    ----
         input_seq: array to decode
         mapping: vocabulary (string), the encoding is given by the indexing of the character sequence

     Returns:
-    -------
         A string, decoded from input_seq
     """
     if not isinstance(input_seq, (Sequence, np.ndarray)):
@@ -108,18 +111,17 @@ def decode_sequence(


 def encode_sequences(
-    sequences:
+    sequences: list[str],
     vocab: str,
-    target_size:
+    target_size: int | None = None,
     eos: int = -1,
-    sos:
-    pad:
+    sos: int | None = None,
+    pad: int | None = None,
     dynamic_seq_length: bool = False,
 ) -> np.ndarray:
     """Encode character sequences using a given vocab as mapping

     Args:
-    ----
         sequences: the list of character sequences of size N
         vocab: the ordered vocab to use for encoding
         target_size: maximum length of the encoded data
@@ -129,7 +131,6 @@ def encode_sequences(
         dynamic_seq_length: if `target_size` is specified, uses it as upper bound and enables dynamic sequence size

     Returns:
-    -------
         the padded encoded data as a tensor
     """
     if 0 <= eos < len(vocab):
@@ -170,29 +171,36 @@ def encode_sequences(


 def convert_target_to_relative(
-    img: ImageTensor, target:
-) ->
+    img: ImageTensor, target: np.ndarray | dict[str, Any]
+) -> tuple[ImageTensor, dict[str, Any] | np.ndarray]:
+    """Converts target to relative coordinates
+
+    Args:
+        img: tf.Tensor or torch.Tensor representing the image
+        target: target to convert to relative coordinates (boxes (N, 4) or polygons (N, 4, 2))
+
+    Returns:
+        The image and the target in relative coordinates
+    """
     if isinstance(target, np.ndarray):
-        target = convert_to_relative_coords(target, get_img_shape(img))
+        target = convert_to_relative_coords(target, get_img_shape(img))  # type: ignore[arg-type]
     else:
-        target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img))
+        target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img))  # type: ignore[arg-type]
     return img, target


-def crop_bboxes_from_image(img_path:
+def crop_bboxes_from_image(img_path: str | Path, geoms: np.ndarray) -> list[np.ndarray]:
     """Crop a set of bounding boxes from an image

     Args:
-    ----
         img_path: path to the image
         geoms: a array of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4)

     Returns:
-    -------
         a list of cropped images
     """
     with Image.open(img_path) as pil_img:
-        img: np.ndarray = np.
+        img: np.ndarray = np.asarray(pil_img.convert("RGB"))
     # Polygon
     if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
         return extract_rcrops(img, geoms.astype(dtype=int))
@@ -201,21 +209,19 @@ def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> Lis
     raise ValueError("Invalid geometry format")


-def pre_transform_multiclass(img, target:
+def pre_transform_multiclass(img, target: tuple[np.ndarray, list]) -> tuple[np.ndarray, dict[str, list]]:
     """Converts multiclass target to relative coordinates.

     Args:
-    ----
         img: Image
         target: tuple of target polygons and their classes names

     Returns:
-    -------
         Image and dictionary of boxes, with class names as keys
     """
     boxes = convert_to_relative_coords(target[0], get_img_shape(img))
     boxes_classes = target[1]
-    boxes_dict:
+    boxes_dict: dict = {k: [] for k in sorted(set(boxes_classes))}
     for k, poly in zip(boxes_classes, boxes):
         boxes_dict[k].append(poly)
     boxes_dict = {k: np.stack(v, axis=0) for k, v in boxes_dict.items()}