python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +8 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +7 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +4 -5
- doctr/datasets/ic13.py +4 -5
- doctr/datasets/iiit5k.py +6 -5
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +6 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +6 -5
- doctr/datasets/svt.py +4 -5
- doctr/datasets/synthtext.py +4 -5
- doctr/datasets/utils.py +34 -29
- doctr/datasets/vocabs.py +17 -7
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +2 -7
- doctr/io/elements.py +59 -79
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +30 -48
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +8 -11
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +5 -17
- doctr/models/classification/mobilenet/tensorflow.py +8 -21
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +6 -8
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +20 -31
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +8 -15
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +9 -12
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +6 -12
- doctr/models/classification/zoo.py +19 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +14 -26
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +14 -23
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +3 -7
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +18 -19
- doctr/models/kie_predictor/tensorflow.py +13 -14
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +12 -13
- doctr/models/predictor/tensorflow.py +8 -9
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +11 -23
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +12 -22
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +16 -22
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +12 -21
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +12 -20
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +14 -17
- doctr/models/utils/tensorflow.py +17 -16
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +16 -47
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/contrib/__init__.py
CHANGED
@@ -0,0 +1 @@
+from .artefacts import ArtefactDetector
doctr/contrib/artefacts.py
CHANGED
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import cv2
 import numpy as np
@@ -14,7 +14,7 @@ from .base import _BasePredictor
 
 __all__ = ["ArtefactDetector"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "yolov8_artefact": {
         "input_shape": (3, 1024, 1024),
         "labels": ["bar_code", "qr_code", "logo", "photo"],
@@ -34,7 +34,6 @@ class ArtefactDetector(_BasePredictor):
     >>> results = detector(doc)
 
     Args:
-    ----
         arch: the architecture to use
         batch_size: the batch size to use
         model_path: the path to the model to use
@@ -50,9 +49,9 @@ class ArtefactDetector(_BasePredictor):
         self,
         arch: str = "yolov8_artefact",
         batch_size: int = 2,
-        model_path: Optional[str] = None,
-        labels: Optional[List[str]] = None,
-        input_shape: Optional[Tuple[int, int, int]] = None,
+        model_path: str | None = None,
+        labels: list[str] | None = None,
+        input_shape: tuple[int, int, int] | None = None,
         conf_threshold: float = 0.5,
         iou_threshold: float = 0.5,
         **kwargs: Any,
@@ -66,7 +65,7 @@ class ArtefactDetector(_BasePredictor):
     def preprocess(self, img: np.ndarray) -> np.ndarray:
         return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)
 
-    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
+    def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> list[list[dict[str, Any]]]:
         results = []
 
         for batch in zip(output, input_images):
@@ -109,7 +108,6 @@ class ArtefactDetector(_BasePredictor):
         Display the results
 
         Args:
-        ----
             **kwargs: additional keyword arguments to be passed to `plt.show`
         """
        requires_package("matplotlib", "`.show()` requires matplotlib installed")
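The pattern in this file repeats across nearly every module in the release: 0.11.0 drops the `typing` aliases (`Dict`, `List`, `Optional`, `Tuple`, `Union`) in favor of builtin generics (PEP 585) and `|` unions (PEP 604), and deletes the `----`/`-------` separator lines from the docstrings. A minimal sketch of the annotation change (illustrative names, not code from the package; the new syntax needs Python 3.10+):

    # 0.10.0 style
    # from typing import Dict, List, Optional
    # def lookup(cfg: Dict[str, List[int]], key: Optional[str] = None) -> Optional[List[int]]: ...

    # 0.11.0 style: builtin generics and PEP 604 unions
    def lookup(cfg: dict[str, list[int]], key: str | None = None) -> list[int] | None:
        # returns the entry for key, or None when no key is given
        return cfg.get(key) if key is not None else None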
doctr/contrib/base.py
CHANGED
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Optional
+from typing import Any
 
 import numpy as np
 
@@ -16,32 +16,29 @@ class _BasePredictor:
     Base class for all predictors
 
     Args:
-    ----
         batch_size: the batch size to use
         url: the url to use to download a model if needed
         model_path: the path to the model to use
         **kwargs: additional arguments to be passed to `download_from_url`
     """
 
-    def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None:
+    def __init__(self, batch_size: int, url: str | None = None, model_path: str | None = None, **kwargs) -> None:
         self.batch_size = batch_size
         self.session = self._init_model(url, model_path, **kwargs)
 
-        self._inputs: List[np.ndarray] = []
-        self._results: List[Any] = []
+        self._inputs: list[np.ndarray] = []
+        self._results: list[Any] = []
 
-    def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
+    def _init_model(self, url: str | None = None, model_path: str | None = None, **kwargs: Any) -> Any:
         """
         Download the model from the given url if needed
 
         Args:
-        ----
             url: the url to use
             model_path: the path to the model to use
             **kwargs: additional arguments to be passed to `download_from_url`
 
         Returns:
-        -------
             Any: the ONNX loaded model
         """
         requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
@@ -57,40 +54,34 @@ class _BasePredictor:
         Preprocess the input image
 
         Args:
-        ----
             img: the input image to preprocess
 
         Returns:
-        -------
             np.ndarray: the preprocessed image
         """
         raise NotImplementedError
 
-    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
+    def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> Any:
         """
         Postprocess the model output
 
         Args:
-        ----
             output: the model output to postprocess
             input_images: the input images used to generate the output
 
         Returns:
-        -------
             Any: the postprocessed output
         """
         raise NotImplementedError
 
-    def __call__(self, inputs: List[np.ndarray]) -> Any:
+    def __call__(self, inputs: list[np.ndarray]) -> Any:
         """
         Call the model on the given inputs
 
         Args:
-        ----
             inputs: the inputs to use
 
         Returns:
-        -------
             Any: the postprocessed output
         """
         self._inputs = inputs
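`_BasePredictor` is a thin ONNX wrapper: `__init__` resolves the model (downloading it when only a `url` is given), and `preprocess`/`postprocess` are the two hooks a subclass fills in, as `ArtefactDetector` does above. A minimal subclass sketch (the hook bodies are placeholders, and instantiating it still requires `onnxruntime` plus a real `url` or `model_path`):

    import numpy as np

    from doctr.contrib.base import _BasePredictor

    class ToyPredictor(_BasePredictor):
        def preprocess(self, img: np.ndarray) -> np.ndarray:
            # placeholder: scale to [0, 1]; real predictors also resize/transpose
            return img.astype(np.float32) / 255.0

        def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> list[np.ndarray]:
            # placeholder: return the raw session outputs unchanged
            return output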
doctr/datasets/cord.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
 
 import numpy as np
 from tqdm import tqdm
@@ -29,7 +29,6 @@ class CORD(VisionDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
@@ -72,12 +71,14 @@ class CORD(VisionDataset):
             + "To get the whole dataset with boxes and labels leave both parameters to False."
         )
 
-        # List images
+        # list images
         tmp_root = os.path.join(self.root, "image")
-        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
+        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
         self.train = train
         np_dtype = np.float32
-        for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))):
+        for img_path in tqdm(
+            iterable=os.listdir(tmp_root), desc="Preparing and Loading CORD", total=len(os.listdir(tmp_root))
+        ):
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_path)):
                 raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
@@ -91,7 +92,7 @@ class CORD(VisionDataset):
                     if len(word["text"]) > 0:
                         x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"]
                         y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"]
-                        box: Union[List[float], np.ndarray]
+                        box: list[float] | np.ndarray
                         if use_polygons:
                             # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                             box = np.array(
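Besides the type hints, the loading loop now spreads the tqdm arguments over their own line and the progress description reads "Preparing and Loading CORD". Dataset usage is unchanged from the docstring above; roughly:

    from doctr.datasets import CORD

    # downloads and unpacks the archive on first use
    train_set = CORD(train=True, download=True)
    img, target = train_set[0]  # boxes and labels unless a task-specific flag was set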
doctr/datasets/datasets/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
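Note the swapped priority in this `__init__.py` (and in the identical hunk under `doctr/datasets/generator/` further down): 0.10.0 imported the TensorFlow implementations first, while 0.11.0 checks for PyTorch first, so PyTorch now wins when both frameworks are installed. The resolution can be inspected at runtime:

    from doctr.file_utils import is_tf_available, is_torch_available

    # mirrors the import-time resolution order in 0.11.0
    backend = "pytorch" if is_torch_available() else "tensorflow" if is_tf_available() else None
    print(f"doctr will use: {backend}")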
doctr/datasets/datasets/base.py
CHANGED
@@ -1,12 +1,13 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import os
 import shutil
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any
 
 import numpy as np
 
@@ -19,15 +20,15 @@ __all__ = ["_AbstractDataset", "_VisionDataset"]
 
 
 class _AbstractDataset:
-    data: List[Any] = []
-    _pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None
+    data: list[Any] = []
+    _pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None
 
     def __init__(
         self,
-        root: Union[str, Path],
-        img_transforms: Optional[Callable[[Any], Any]] = None,
-        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
-        pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+        root: str | Path,
+        img_transforms: Callable[[Any], Any] | None = None,
+        sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
+        pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
     ) -> None:
         if not Path(root).is_dir():
             raise ValueError(f"expected a path to a reachable folder: {root}")
@@ -41,10 +42,10 @@ class _AbstractDataset:
     def __len__(self) -> int:
         return len(self.data)
 
-    def _read_sample(self, index: int) -> Tuple[Any, Any]:
+    def _read_sample(self, index: int) -> tuple[Any, Any]:
         raise NotImplementedError
 
-    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
         # Read image
         img, target = self._read_sample(index)
         # Pre-transforms (format conversion at run-time etc.)
@@ -82,7 +83,6 @@ class _VisionDataset(_AbstractDataset):
     """Implements an abstract dataset
 
     Args:
-    ----
         url: URL of the dataset
         file_name: name of the file once downloaded
         file_hash: expected SHA256 of the file
@@ -96,13 +96,13 @@ class _VisionDataset(_AbstractDataset):
     def __init__(
         self,
         url: str,
-        file_name: Optional[str] = None,
-        file_hash: Optional[str] = None,
+        file_name: str | None = None,
+        file_hash: str | None = None,
         extract_archive: bool = False,
         download: bool = False,
         overwrite: bool = False,
-        cache_dir: Optional[str] = None,
-        cache_subdir: Optional[str] = None,
+        cache_dir: str | None = None,
+        cache_subdir: str | None = None,
         **kwargs: Any,
     ) -> None:
         cache_dir = (
@@ -115,7 +115,7 @@ class _VisionDataset(_AbstractDataset):
 
         file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
         # Download the file if not present
-        archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name)
+        archive_path: str | Path = os.path.join(cache_dir, cache_subdir, file_name)
 
         if not os.path.exists(archive_path) and not download:
             raise ValueError("the dataset needs to be downloaded first with download=True")
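One more mechanical change appears here: `Callable` is now imported from `collections.abc` instead of `typing`, following the deprecation of the `typing` alias. In the new style a transform annotation reads like this (hypothetical alias, for illustration only):

    from collections.abc import Callable
    from typing import Any

    # a sample transform maps an (img, target) pair to a new (img, target) pair
    SampleTransform = Callable[[Any, Any], tuple[Any, Any]]

    def apply(img: Any, target: Any, fn: SampleTransform | None = None) -> tuple[Any, Any]:
        # run the transform when one is provided, otherwise pass through
        return fn(img, target) if fn is not None else (img, target)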
doctr/datasets/datasets/pytorch.py
CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import os
 from copy import deepcopy
-from typing import Any, List, Tuple
+from typing import Any
 
 import numpy as np
 import torch
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
 class AbstractDataset(_AbstractDataset):
     """Abstract class for all datasets"""
 
-    def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
+    def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]:
         img_name, target = self.data[index]
 
         # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
             assert "labels" in target, "Target should contain 'labels' key"
         elif isinstance(target, tuple):
             assert len(target) == 2
-            assert isinstance(target[0], str) or isinstance(
-                target[0], np.ndarray
-            ), "first element of the tuple should be a string or a numpy array"
+            assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+                "first element of the tuple should be a string or a numpy array"
+            )
             assert isinstance(target[1], list), "second element of the tuple should be a list"
         else:
-            assert isinstance(target, str) or isinstance(
-                target, np.ndarray
-            ), "Target should be a string or a numpy array"
+            assert isinstance(target, str) or isinstance(target, np.ndarray), (
+                "Target should be a string or a numpy array"
+            )
 
         # Read image
         img = (
@@ -48,11 +48,11 @@ class AbstractDataset(_AbstractDataset):
         return img, deepcopy(target)
 
     @staticmethod
-    def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
+    def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]:
         images, targets = zip(*samples)
-        images = torch.stack(images, dim=0)  # type: ignore[assignment]
+        images = torch.stack(images, dim=0)
 
-        return images, list(targets)  # type: ignore[return-value]
+        return images, list(targets)
 
 
 class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
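`collate_fn` exists because the targets are heterogeneous dicts or strings that default collation cannot batch: it stacks only the image tensors and keeps targets as a plain list. A self-contained usage sketch with dummy data (doctr's datasets pass this function to the DataLoader the same way):

    import torch
    from torch.utils.data import DataLoader

    # dummy samples standing in for (image, target) pairs from a doctr dataset
    samples = [(torch.rand(3, 32, 32), {"boxes": [], "labels": [i]}) for i in range(8)]

    def collate_fn(batch):
        images, targets = zip(*batch)
        return torch.stack(images, dim=0), list(targets)

    loader = DataLoader(samples, batch_size=4, collate_fn=collate_fn)
    images, targets = next(iter(loader))
    print(images.shape)  # torch.Size([4, 3, 32, 32])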
doctr/datasets/datasets/tensorflow.py
CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import os
 from copy import deepcopy
-from typing import Any, List, Tuple
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
 class AbstractDataset(_AbstractDataset):
     """Abstract class for all datasets"""
 
-    def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
+    def _read_sample(self, index: int) -> tuple[tf.Tensor, Any]:
         img_name, target = self.data[index]
 
         # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
             assert "labels" in target, "Target should contain 'labels' key"
         elif isinstance(target, tuple):
             assert len(target) == 2
-            assert isinstance(target[0], str) or isinstance(
-                target[0], np.ndarray
-            ), "first element of the tuple should be a string or a numpy array"
+            assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+                "first element of the tuple should be a string or a numpy array"
+            )
             assert isinstance(target[1], list), "second element of the tuple should be a list"
         else:
-            assert isinstance(target, str) or isinstance(
-                target, np.ndarray
-            ), "Target should be a string or a numpy array"
+            assert isinstance(target, str) or isinstance(target, np.ndarray), (
+                "Target should be a string or a numpy array"
+            )
 
         # Read image
         img = (
@@ -48,7 +48,7 @@ class AbstractDataset(_AbstractDataset):
         return img, deepcopy(target)
 
     @staticmethod
-    def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]:
+    def collate_fn(samples: list[tuple[tf.Tensor, Any]]) -> tuple[tf.Tensor, list[Any]]:
         images, targets = zip(*samples)
         images = tf.stack(images, axis=0)
 
doctr/datasets/detection.py
CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import json
 import os
-from typing import Any, Dict, List, Tuple, Type, Union
+from typing import Any
 
 import numpy as np
 
@@ -26,7 +26,6 @@ class DetectionDataset(AbstractDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         img_folder: folder with all the images of the dataset
         label_path: path to the annotations of each image
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
@@ -47,13 +46,13 @@ class DetectionDataset(AbstractDataset):
         )
 
         # File existence check
-        self._class_names: List = []
+        self._class_names: list = []
         if not os.path.exists(label_path):
             raise FileNotFoundError(f"unable to locate {label_path}")
         with open(label_path, "rb") as f:
             labels = json.load(f)
 
-        self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
+        self.data: list[tuple[str, tuple[np.ndarray, list[str]]]] = []
         np_dtype = np.float32
         for img_name, label in labels.items():
             # File existence check
@@ -65,18 +64,16 @@ class DetectionDataset(AbstractDataset):
             self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))
 
     def format_polygons(
-        self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
-    ) -> Tuple[np.ndarray, List[str]]:
+        self, polygons: list | dict, use_polygons: bool, np_dtype: type
+    ) -> tuple[np.ndarray, list[str]]:
         """Format polygons into an array
 
         Args:
-        ----
             polygons: the bounding boxes
             use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
             np_dtype: dtype of array
 
         Returns:
-        -------
             geoms: bounding boxes as np array
             polygons_classes: list of classes for each bounding box
         """
doctr/datasets/doc_artefacts.py
CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import json
 import os
-from typing import Any, Dict, List, Tuple
+from typing import Any
 
 import numpy as np
 
@@ -26,7 +26,6 @@ class DocArtefacts(VisionDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         **kwargs: keyword arguments from `VisionDataset`.
@@ -51,7 +50,7 @@ class DocArtefacts(VisionDataset):
         tmp_root = os.path.join(self.root, "images")
         with open(os.path.join(self.root, "labels.json"), "rb") as f:
             labels = json.load(f)
-        self.data: List[Tuple[str, Dict[str, Any]]] = []
+        self.data: list[tuple[str, dict[str, Any]]] = []
         img_list = os.listdir(tmp_root)
         if len(labels) != len(img_list):
             raise AssertionError("the number of images and labels do not match")
doctr/datasets/funsd.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
 
 import numpy as np
 from tqdm import tqdm
@@ -29,7 +29,6 @@ class FUNSD(VisionDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
@@ -69,10 +68,12 @@ class FUNSD(VisionDataset):
         # Use the subset
         subfolder = os.path.join("dataset", "training_data" if train else "testing_data")
 
-        # # List images
+        # # list images
         tmp_root = os.path.join(self.root, subfolder, "images")
-        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
-        for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))):
+        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
+        for img_path in tqdm(
+            iterable=os.listdir(tmp_root), desc="Preparing and Loading FUNSD", total=len(os.listdir(tmp_root))
+        ):
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_path)):
                 raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
doctr/datasets/generator/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/datasets/generator/base.py
CHANGED
@@ -1,10 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import random
-from typing import Any, Callable, List, Optional, Tuple, Union
+from collections.abc import Callable
+from typing import Any
 
 from PIL import Image, ImageDraw
 
@@ -17,14 +18,13 @@ from ..datasets import AbstractDataset
 def synthesize_text_img(
     text: str,
     font_size: int = 32,
-    font_family: Optional[str] = None,
-    background_color: Optional[Tuple[int, int, int]] = None,
-    text_color: Optional[Tuple[int, int, int]] = None,
+    font_family: str | None = None,
+    background_color: tuple[int, int, int] | None = None,
+    text_color: tuple[int, int, int] | None = None,
 ) -> Image.Image:
     """Generate a synthetic text image
 
     Args:
-    ----
         text: the text to render as an image
         font_size: the size of the font
         font_family: the font family (has to be installed on your system)
@@ -32,7 +32,6 @@ def synthesize_text_img(
         text_color: text color on the final image
 
     Returns:
-    -------
         PIL image of the text
     """
     background_color = (0, 0, 0) if background_color is None else background_color
@@ -61,9 +60,9 @@ class _CharacterGenerator(AbstractDataset):
         vocab: str,
         num_samples: int,
         cache_samples: bool = False,
-        font_family: Optional[Union[str, List[str]]] = None,
-        img_transforms: Optional[Callable[[Any], Any]] = None,
-        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+        font_family: str | list[str] | None = None,
+        img_transforms: Callable[[Any], Any] | None = None,
+        sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
     ) -> None:
         self.vocab = vocab
         self._num_samples = num_samples
@@ -78,7 +77,7 @@ class _CharacterGenerator(AbstractDataset):
         self.img_transforms = img_transforms
         self.sample_transforms = sample_transforms
 
-        self._data: List[Image.Image] = []
+        self._data: list[Image.Image] = []
         if cache_samples:
             self._data = [
                 (synthesize_text_img(char, font_family=font), idx)  # type: ignore[misc]
@@ -89,7 +88,7 @@ class _CharacterGenerator(AbstractDataset):
     def __len__(self) -> int:
         return self._num_samples
 
-    def _read_sample(self, index: int) -> Tuple[Any, int]:
+    def _read_sample(self, index: int) -> tuple[Any, int]:
         # Samples are already cached
         if len(self._data) > 0:
             idx = index % len(self._data)
@@ -110,9 +109,9 @@ class _WordGenerator(AbstractDataset):
         max_chars: int,
         num_samples: int,
         cache_samples: bool = False,
-        font_family: Optional[Union[str, List[str]]] = None,
-        img_transforms: Optional[Callable[[Any], Any]] = None,
-        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+        font_family: str | list[str] | None = None,
+        img_transforms: Callable[[Any], Any] | None = None,
+        sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
     ) -> None:
         self.vocab = vocab
         self.wordlen_range = (min_chars, max_chars)
@@ -128,7 +127,7 @@ class _WordGenerator(AbstractDataset):
         self.img_transforms = img_transforms
         self.sample_transforms = sample_transforms
 
-        self._data: List[Image.Image] = []
+        self._data: list[Image.Image] = []
         if cache_samples:
             _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
             self._data = [
@@ -143,7 +142,7 @@ class _WordGenerator(AbstractDataset):
     def __len__(self) -> int:
         return self._num_samples
 
-    def _read_sample(self, index: int) -> Tuple[Any, str]:
+    def _read_sample(self, index: int) -> tuple[Any, str]:
         # Samples are already cached
         if len(self._data) > 0:
             pil_img, target = self._data[index]  # type: ignore[misc]