python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/contrib/__init__.py
CHANGED

```diff
@@ -0,0 +1 @@
+from .artefacts import ArtefactDetector
```
doctr/contrib/artefacts.py
CHANGED

```diff
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import cv2
 import numpy as np
@@ -14,7 +14,7 @@ from .base import _BasePredictor
 
 __all__ = ["ArtefactDetector"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "yolov8_artefact": {
         "input_shape": (3, 1024, 1024),
         "labels": ["bar_code", "qr_code", "logo", "photo"],
@@ -34,7 +34,6 @@ class ArtefactDetector(_BasePredictor):
     >>> results = detector(doc)
 
     Args:
-    ----
         arch: the architecture to use
         batch_size: the batch size to use
         model_path: the path to the model to use
@@ -50,9 +49,9 @@ class ArtefactDetector(_BasePredictor):
         self,
         arch: str = "yolov8_artefact",
         batch_size: int = 2,
-        model_path: Optional[str] = None,
-        labels: Optional[List[str]] = None,
-        input_shape: Optional[Tuple[int, int, int]] = None,
+        model_path: str | None = None,
+        labels: list[str] | None = None,
+        input_shape: tuple[int, int, int] | None = None,
         conf_threshold: float = 0.5,
         iou_threshold: float = 0.5,
         **kwargs: Any,
@@ -66,7 +65,7 @@ class ArtefactDetector(_BasePredictor):
     def preprocess(self, img: np.ndarray) -> np.ndarray:
         return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)
 
-    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
+    def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> list[list[dict[str, Any]]]:
         results = []
 
         for batch in zip(output, input_images):
@@ -109,7 +108,6 @@ class ArtefactDetector(_BasePredictor):
         Display the results
 
         Args:
-        ----
             **kwargs: additional keyword arguments to be passed to `plt.show`
         """
         requires_package("matplotlib", "`.show()` requires matplotlib installed")
```
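The `ArtefactDetector` changes above are signature-only (builtin generics and `X | None` unions). A minimal usage sketch of that signature, assuming the `onnxruntime` dependency required by `doctr.contrib` is installed and that `doc` is a list of HWC `uint8` numpy arrays; the contents of the result dicts are illustrative:

```python
import numpy as np

from doctr.contrib import ArtefactDetector

detector = ArtefactDetector(
    arch="yolov8_artefact",  # the only entry in default_cfgs
    batch_size=2,
    conf_threshold=0.5,  # confidence threshold, matching the defaults above
    iou_threshold=0.5,   # IoU threshold used when filtering overlapping boxes
)
doc = [np.zeros((1024, 1024, 3), dtype=np.uint8)]  # placeholder page
results = detector(doc)  # one list of detection dicts per input image
detector.show()  # plots the detections; requires matplotlib
```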
doctr/contrib/base.py
CHANGED

```diff
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Optional
+from typing import Any
 
 import numpy as np
 
@@ -16,32 +16,29 @@ class _BasePredictor:
     Base class for all predictors
 
     Args:
-    ----
         batch_size: the batch size to use
         url: the url to use to download a model if needed
         model_path: the path to the model to use
         **kwargs: additional arguments to be passed to `download_from_url`
     """
 
-    def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None:
+    def __init__(self, batch_size: int, url: str | None = None, model_path: str | None = None, **kwargs) -> None:
         self.batch_size = batch_size
         self.session = self._init_model(url, model_path, **kwargs)
 
-        self._inputs: List[np.ndarray] = []
-        self._results: List[Any] = []
+        self._inputs: list[np.ndarray] = []
+        self._results: list[Any] = []
 
-    def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
+    def _init_model(self, url: str | None = None, model_path: str | None = None, **kwargs: Any) -> Any:
         """
         Download the model from the given url if needed
 
         Args:
-        ----
             url: the url to use
             model_path: the path to the model to use
             **kwargs: additional arguments to be passed to `download_from_url`
 
         Returns:
-        -------
             Any: the ONNX loaded model
         """
         requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
@@ -57,40 +54,34 @@ class _BasePredictor:
         Preprocess the input image
 
         Args:
-        ----
             img: the input image to preprocess
 
         Returns:
-        -------
             np.ndarray: the preprocessed image
         """
         raise NotImplementedError
 
-    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
+    def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> Any:
         """
         Postprocess the model output
 
         Args:
-        ----
             output: the model output to postprocess
             input_images: the input images used to generate the output
 
         Returns:
-        -------
             Any: the postprocessed output
         """
         raise NotImplementedError
 
-    def __call__(self, inputs: List[np.ndarray]) -> Any:
+    def __call__(self, inputs: list[np.ndarray]) -> Any:
         """
         Call the model on the given inputs
 
         Args:
-        ----
             inputs: the inputs to use
 
         Returns:
-        -------
             Any: the postprocessed output
         """
         self._inputs = inputs
```
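For context, `_BasePredictor` keeps the ONNX session setup and batching; the two `NotImplementedError` methods above define the contract a concrete predictor fills in (as `ArtefactDetector` does). A hypothetical subclass sketch, using the signatures shown in the diff:

```python
from typing import Any

import numpy as np

from doctr.contrib.base import _BasePredictor


class MyPredictor(_BasePredictor):  # hypothetical example
    def preprocess(self, img: np.ndarray) -> np.ndarray:
        # bring the HWC uint8 image to the CHW float32 layout an ONNX model expects
        return img.transpose(2, 0, 1).astype(np.float32) / 255.0

    def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> Any:
        # map raw session outputs back onto the original images
        return output
```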
doctr/datasets/cord.py
CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
 
 import numpy as np
 from tqdm import tqdm
@@ -29,10 +29,10 @@ class CORD(VisionDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
+        detection_task: whether the dataset should be used for detection task
         **kwargs: keyword arguments from `VisionDataset`.
     """
 
@@ -53,6 +53,7 @@ class CORD(VisionDataset):
         train: bool = True,
         use_polygons: bool = False,
         recognition_task: bool = False,
+        detection_task: bool = False,
         **kwargs: Any,
     ) -> None:
         url, sha256, name = self.TRAIN if train else self.TEST
@@ -64,13 +65,20 @@ class CORD(VisionDataset):
             pre_transforms=convert_target_to_relative if not recognition_task else None,
             **kwargs,
         )
+        if recognition_task and detection_task:
+            raise ValueError(
+                "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                + "To get the whole dataset with boxes and labels leave both parameters to False."
+            )
 
-        # List images
+        # list images
         tmp_root = os.path.join(self.root, "image")
-        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
         self.train = train
         np_dtype = np.float32
-        for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))):
+        for img_path in tqdm(
+            iterable=os.listdir(tmp_root), desc="Preparing and Loading CORD", total=len(os.listdir(tmp_root))
+        ):
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_path)):
                 raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
@@ -84,7 +92,7 @@ class CORD(VisionDataset):
                     if len(word["text"]) > 0:
                         x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"]
                         y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"]
-                        box: Union[List[float], np.ndarray]
+                        box: list[float] | np.ndarray
                         if use_polygons:
                             # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                             box = np.array(
@@ -109,6 +117,8 @@ class CORD(VisionDataset):
                 )
                 for crop, label in zip(crops, list(text_targets)):
                     self.data.append((crop, label))
+            elif detection_task:
+                self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0)))
             else:
                 self.data.append((
                     img_path,
```
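With the new flag, `CORD` can be loaded in three mutually exclusive modes; combining both task flags raises the `ValueError` added above:

```python
from doctr.datasets import CORD

full_set = CORD(train=True, download=True)  # images with boxes and labels
reco_set = CORD(train=True, download=True, recognition_task=True)  # word crops and strings
det_set = CORD(train=True, download=True, detection_task=True)  # images and boxes only

# raises ValueError: the two task flags are mutually exclusive
CORD(train=True, download=True, recognition_task=True, detection_task=True)
```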
doctr/datasets/datasets/__init__.py
CHANGED

```diff
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
```
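These backend stubs now resolve to the PyTorch implementation first when both frameworks are installed. doctr's `file_utils` also reads the `USE_TF`/`USE_TORCH` environment variables, so a backend can still be forced explicitly; a sketch:

```python
import os

os.environ["USE_TORCH"] = "1"  # must be set before the first doctr import

from doctr.datasets import CORD  # now backed by the PyTorch dataset classes
```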
doctr/datasets/datasets/base.py
CHANGED

```diff
@@ -1,12 +1,13 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import os
 import shutil
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any
 
 import numpy as np
 
@@ -19,15 +20,15 @@ __all__ = ["_AbstractDataset", "_VisionDataset"]
 
 
 class _AbstractDataset:
-    data: List[Any] = []
-    _pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None
+    data: list[Any] = []
+    _pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None
 
     def __init__(
         self,
-        root: Union[str, Path],
-        img_transforms: Optional[Callable[[Any], Any]] = None,
-        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
-        pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+        root: str | Path,
+        img_transforms: Callable[[Any], Any] | None = None,
+        sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
+        pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
     ) -> None:
         if not Path(root).is_dir():
             raise ValueError(f"expected a path to a reachable folder: {root}")
@@ -41,10 +42,10 @@ class _AbstractDataset:
     def __len__(self) -> int:
         return len(self.data)
 
-    def _read_sample(self, index: int) -> Tuple[Any, Any]:
+    def _read_sample(self, index: int) -> tuple[Any, Any]:
         raise NotImplementedError
 
-    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
         # Read image
         img, target = self._read_sample(index)
         # Pre-transforms (format conversion at run-time etc.)
@@ -82,7 +83,6 @@ class _VisionDataset(_AbstractDataset):
     """Implements an abstract dataset
 
     Args:
-    ----
         url: URL of the dataset
         file_name: name of the file once downloaded
         file_hash: expected SHA256 of the file
@@ -96,13 +96,13 @@ class _VisionDataset(_AbstractDataset):
     def __init__(
         self,
         url: str,
-        file_name: Optional[str] = None,
-        file_hash: Optional[str] = None,
+        file_name: str | None = None,
+        file_hash: str | None = None,
         extract_archive: bool = False,
         download: bool = False,
         overwrite: bool = False,
-        cache_dir: Optional[str] = None,
-        cache_subdir: Optional[str] = None,
+        cache_dir: str | None = None,
+        cache_subdir: str | None = None,
         **kwargs: Any,
     ) -> None:
         cache_dir = (
@@ -115,7 +115,7 @@ class _VisionDataset(_AbstractDataset):
 
         file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
         # Download the file if not present
-        archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name)
+        archive_path: str | Path = os.path.join(cache_dir, cache_subdir, file_name)
 
         if not os.path.exists(archive_path) and not download:
             raise ValueError("the dataset needs to be downloaded first with download=True")
```
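The three hooks typed above compose at read time. A pseudo-trace of `__getitem__`, assuming its body simply chains the optional hooks in order (consistent with the context lines shown in the diff):

```python
def getitem(dataset, index: int):
    img, target = dataset._read_sample(index)  # backend-specific image read
    # Pre-transforms (format conversion at run-time etc.)
    if dataset._pre_transforms is not None:
        img, target = dataset._pre_transforms(img, target)
    if dataset.img_transforms is not None:
        img = dataset.img_transforms(img)  # image-only transforms
    if dataset.sample_transforms is not None:
        img, target = dataset.sample_transforms(img, target)  # joint transforms
    return img, target
```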
doctr/datasets/datasets/pytorch.py
CHANGED

```diff
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import os
 from copy import deepcopy
-from typing import Any, List, Tuple
+from typing import Any
 
 import numpy as np
 import torch
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
 class AbstractDataset(_AbstractDataset):
     """Abstract class for all datasets"""
 
-    def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
+    def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]:
         img_name, target = self.data[index]
 
         # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
             assert "labels" in target, "Target should contain 'labels' key"
         elif isinstance(target, tuple):
             assert len(target) == 2
-            assert isinstance(target[0], str) or isinstance(
-                target[0], np.ndarray
-            ), "first element of the tuple should be a string or a numpy array"
+            assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+                "first element of the tuple should be a string or a numpy array"
+            )
             assert isinstance(target[1], list), "second element of the tuple should be a list"
         else:
-            assert isinstance(target, str) or isinstance(
-                target, np.ndarray
-            ), "Target should be a string or a numpy array"
+            assert isinstance(target, str) or isinstance(target, np.ndarray), (
+                "Target should be a string or a numpy array"
+            )
 
         # Read image
         img = (
@@ -48,11 +48,11 @@ class AbstractDataset(_AbstractDataset):
         return img, deepcopy(target)
 
     @staticmethod
-    def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
+    def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]:
         images, targets = zip(*samples)
-        images = torch.stack(images, dim=0)  # type: ignore[assignment]
+        images = torch.stack(images, dim=0)
 
-        return images, list(targets)  # type: ignore[return-value]
+        return images, list(targets)
 
 
 class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
```
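`collate_fn` stacks the image tensors but deliberately keeps the heterogeneous targets as a plain list, which makes these datasets directly usable with a torch `DataLoader`; a sketch:

```python
from torch.utils.data import DataLoader

from doctr.datasets import CORD

train_set = CORD(train=True, download=True)
loader = DataLoader(train_set, batch_size=8, collate_fn=train_set.collate_fn)
images, targets = next(iter(loader))  # (8, C, H, W) tensor and a list of 8 targets
```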
doctr/datasets/datasets/tensorflow.py
CHANGED

```diff
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import os
 from copy import deepcopy
-from typing import Any, List, Tuple
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
 class AbstractDataset(_AbstractDataset):
     """Abstract class for all datasets"""
 
-    def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
+    def _read_sample(self, index: int) -> tuple[tf.Tensor, Any]:
         img_name, target = self.data[index]
 
         # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
             assert "labels" in target, "Target should contain 'labels' key"
         elif isinstance(target, tuple):
             assert len(target) == 2
-            assert isinstance(target[0], str) or isinstance(
-                target[0], np.ndarray
-            ), "first element of the tuple should be a string or a numpy array"
+            assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+                "first element of the tuple should be a string or a numpy array"
+            )
             assert isinstance(target[1], list), "second element of the tuple should be a list"
         else:
-            assert isinstance(target, str) or isinstance(
-                target, np.ndarray
-            ), "Target should be a string or a numpy array"
+            assert isinstance(target, str) or isinstance(target, np.ndarray), (
+                "Target should be a string or a numpy array"
+            )
 
         # Read image
         img = (
@@ -48,7 +48,7 @@ class AbstractDataset(_AbstractDataset):
         return img, deepcopy(target)
 
     @staticmethod
-    def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]:
+    def collate_fn(samples: list[tuple[tf.Tensor, Any]]) -> tuple[tf.Tensor, list[Any]]:
         images, targets = zip(*samples)
         images = tf.stack(images, axis=0)
 
```
doctr/datasets/detection.py
CHANGED

```diff
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import json
 import os
-from typing import Any, Dict, List, Tuple, Type, Union
+from typing import Any
 
 import numpy as np
 
@@ -26,7 +26,6 @@ class DetectionDataset(AbstractDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         img_folder: folder with all the images of the dataset
         label_path: path to the annotations of each image
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
@@ -47,13 +46,13 @@ class DetectionDataset(AbstractDataset):
         )
 
         # File existence check
-        self._class_names: List = []
+        self._class_names: list = []
         if not os.path.exists(label_path):
             raise FileNotFoundError(f"unable to locate {label_path}")
         with open(label_path, "rb") as f:
             labels = json.load(f)
 
-        self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
+        self.data: list[tuple[str, tuple[np.ndarray, list[str]]]] = []
         np_dtype = np.float32
         for img_name, label in labels.items():
             # File existence check
@@ -65,18 +64,16 @@ class DetectionDataset(AbstractDataset):
             self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))
 
     def format_polygons(
-        self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
-    ) -> Tuple[np.ndarray, List[str]]:
+        self, polygons: list | dict, use_polygons: bool, np_dtype: type
+    ) -> tuple[np.ndarray, list[str]]:
         """Format polygons into an array
 
         Args:
-        ----
             polygons: the bounding boxes
             use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
             np_dtype: dtype of array
 
         Returns:
-        -------
             geoms: bounding boxes as np array
             polygons_classes: list of classes for each bounding box
         """
```
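`DetectionDataset` is the file-backed dataset used for detection training: a single JSON file maps each image name to its polygons, and per the annotations above each target is a `(boxes, class_names)` pair. A usage sketch with hypothetical paths:

```python
from doctr.datasets import DetectionDataset

train_set = DetectionDataset(
    img_folder="path/to/images",  # hypothetical location
    label_path="path/to/labels.json",  # single JSON annotation file
    use_polygons=False,  # straight boxes instead of rotated polygons
)
img, target = train_set[0]
```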
doctr/datasets/doc_artefacts.py
CHANGED

```diff
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import json
 import os
-from typing import Any, Dict, List, Tuple
+from typing import Any
 
 import numpy as np
 
@@ -26,7 +26,6 @@ class DocArtefacts(VisionDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         **kwargs: keyword arguments from `VisionDataset`.
@@ -51,7 +50,7 @@ class DocArtefacts(VisionDataset):
         tmp_root = os.path.join(self.root, "images")
         with open(os.path.join(self.root, "labels.json"), "rb") as f:
             labels = json.load(f)
-        self.data: List[Tuple[str, Dict[str, Any]]] = []
+        self.data: list[tuple[str, dict[str, Any]]] = []
         img_list = os.listdir(tmp_root)
         if len(labels) != len(img_list):
             raise AssertionError("the number of images and labels do not match")
```
doctr/datasets/funsd.py
CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
 
 import numpy as np
 from tqdm import tqdm
@@ -29,10 +29,10 @@ class FUNSD(VisionDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
+        detection_task: whether the dataset should be used for detection task
         **kwargs: keyword arguments from `VisionDataset`.
     """
 
@@ -45,6 +45,7 @@ class FUNSD(VisionDataset):
         train: bool = True,
         use_polygons: bool = False,
         recognition_task: bool = False,
+        detection_task: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -55,16 +56,24 @@ class FUNSD(VisionDataset):
             pre_transforms=convert_target_to_relative if not recognition_task else None,
             **kwargs,
         )
+        if recognition_task and detection_task:
+            raise ValueError(
+                "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                + "To get the whole dataset with boxes and labels leave both parameters to False."
+            )
+
         self.train = train
         np_dtype = np.float32
 
         # Use the subset
         subfolder = os.path.join("dataset", "training_data" if train else "testing_data")
 
-        # # List images
+        # # list images
         tmp_root = os.path.join(self.root, subfolder, "images")
-        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
-        for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))):
+        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
+        for img_path in tqdm(
+            iterable=os.listdir(tmp_root), desc="Preparing and Loading FUNSD", total=len(os.listdir(tmp_root))
+        ):
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_path)):
                 raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
@@ -100,6 +109,8 @@ class FUNSD(VisionDataset):
                     # filter labels with unknown characters
                     if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]):
                         self.data.append((crop, label))
+            elif detection_task:
+                self.data.append((img_path, np.asarray(box_targets, dtype=np_dtype)))
             else:
                 self.data.append((
                     img_path,
```
doctr/datasets/generator/__init__.py
CHANGED

```diff
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
```