python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/contrib/__init__.py
CHANGED
@@ -0,0 +1 @@
+from .artefacts import ArtefactDetector
doctr/contrib/artefacts.py
CHANGED
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
 import cv2
 import numpy as np
@@ -14,7 +14,7 @@ from .base import _BasePredictor
 
 __all__ = ["ArtefactDetector"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "yolov8_artefact": {
         "input_shape": (3, 1024, 1024),
         "labels": ["bar_code", "qr_code", "logo", "photo"],
@@ -34,7 +34,6 @@ class ArtefactDetector(_BasePredictor):
     >>> results = detector(doc)
 
     Args:
-    ----
         arch: the architecture to use
         batch_size: the batch size to use
         model_path: the path to the model to use
@@ -50,9 +49,9 @@ class ArtefactDetector(_BasePredictor):
         self,
         arch: str = "yolov8_artefact",
         batch_size: int = 2,
-        model_path:
-        labels:
-        input_shape:
+        model_path: str | None = None,
+        labels: list[str] | None = None,
+        input_shape: tuple[int, int, int] | None = None,
         conf_threshold: float = 0.5,
         iou_threshold: float = 0.5,
         **kwargs: Any,
@@ -66,7 +65,7 @@ class ArtefactDetector(_BasePredictor):
     def preprocess(self, img: np.ndarray) -> np.ndarray:
         return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)
 
-    def postprocess(self, output:
+    def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> list[list[dict[str, Any]]]:
         results = []
 
         for batch in zip(output, input_images):
@@ -109,7 +108,6 @@ class ArtefactDetector(_BasePredictor):
         Display the results
 
         Args:
-        ----
             **kwargs: additional keyword arguments to be passed to `plt.show`
         """
         requires_package("matplotlib", "`.show()` requires matplotlib installed")
doctr/contrib/base.py
CHANGED
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
 import numpy as np
 
@@ -16,32 +16,29 @@ class _BasePredictor:
     Base class for all predictors
 
     Args:
-    ----
         batch_size: the batch size to use
         url: the url to use to download a model if needed
         model_path: the path to the model to use
         **kwargs: additional arguments to be passed to `download_from_url`
     """
 
-    def __init__(self, batch_size: int, url:
+    def __init__(self, batch_size: int, url: str | None = None, model_path: str | None = None, **kwargs) -> None:
         self.batch_size = batch_size
         self.session = self._init_model(url, model_path, **kwargs)
 
-        self._inputs:
-        self._results:
+        self._inputs: list[np.ndarray] = []
+        self._results: list[Any] = []
 
-    def _init_model(self, url:
+    def _init_model(self, url: str | None = None, model_path: str | None = None, **kwargs: Any) -> Any:
         """
         Download the model from the given url if needed
 
         Args:
-        ----
             url: the url to use
             model_path: the path to the model to use
             **kwargs: additional arguments to be passed to `download_from_url`
 
         Returns:
-        -------
             Any: the ONNX loaded model
         """
         requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
@@ -57,40 +54,34 @@ class _BasePredictor:
         Preprocess the input image
 
         Args:
-        ----
             img: the input image to preprocess
 
         Returns:
-        -------
             np.ndarray: the preprocessed image
         """
         raise NotImplementedError
 
-    def postprocess(self, output:
+    def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> Any:
         """
         Postprocess the model output
 
         Args:
-        ----
             output: the model output to postprocess
             input_images: the input images used to generate the output
 
         Returns:
-        -------
             Any: the postprocessed output
         """
         raise NotImplementedError
 
-    def __call__(self, inputs:
+    def __call__(self, inputs: list[np.ndarray]) -> Any:
         """
         Call the model on the given inputs
 
         Args:
-        ----
             inputs: the inputs to use
 
         Returns:
-        -------
             Any: the postprocessed output
         """
         self._inputs = inputs
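For context, a minimal sketch of how a downstream subclass could build on the updated `_BasePredictor` hooks shown above; the class name, model path and the trivial postprocessing are illustrative placeholders, not part of the package:

from typing import Any

import numpy as np

from doctr.contrib.base import _BasePredictor


class MyOnnxPredictor(_BasePredictor):  # hypothetical subclass, not shipped with doctr
    def preprocess(self, img: np.ndarray) -> np.ndarray:
        # HWC uint8 image -> CHW float32 in [0, 1], mirroring ArtefactDetector.preprocess above
        return np.transpose(img, (2, 0, 1)).astype(np.float32) / 255.0

    def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> Any:
        # Placeholder: hand back the raw ONNX session outputs, one entry per batch
        return output


# Requires onnxruntime and a local ONNX model:
# predictor = MyOnnxPredictor(batch_size=2, model_path="/path/to/model.onnx")
# results = predictor([np.zeros((1024, 1024, 3), dtype=np.uint8)])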
doctr/datasets/coco_text.py
ADDED
@@ -0,0 +1,139 @@
+# Copyright (C) 2021-2025, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import json
+import os
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+from tqdm import tqdm
+
+from .datasets import AbstractDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["COCOTEXT"]
+
+
+class COCOTEXT(AbstractDataset):
+    """
+    COCO-Text dataset from `"COCO-Text: Dataset and Benchmark for Text Detection and Recognition in Natural Images"
+    <https://arxiv.org/pdf/1601.07140v2>`_ |
+    `"homepage" <https://bgshih.github.io/cocotext/>`_.
+
+    >>> # NOTE: You need to download the dataset first.
+    >>> from doctr.datasets import COCOTEXT
+    >>> train_set = COCOTEXT(train=True, img_folder="/path/to/coco_text/train2014/",
+    >>>                      label_path="/path/to/coco_text/cocotext.v2.json")
+    >>> img, target = train_set[0]
+    >>> test_set = COCOTEXT(train=False, img_folder="/path/to/coco_text/train2014/",
+    >>>                     label_path = "/path/to/coco_text/cocotext.v2.json")
+    >>> img, target = test_set[0]
+
+    Args:
+        img_folder: folder with all the images of the dataset
+        label_path: path to the annotations file of the dataset
+        train: whether the subset should be the training one
+        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+        recognition_task: whether the dataset should be used for recognition task
+        detection_task: whether the dataset should be used for detection task
+        **kwargs: keyword arguments from `AbstractDataset`.
+    """
+
+    def __init__(
+        self,
+        img_folder: str,
+        label_path: str,
+        train: bool = True,
+        use_polygons: bool = False,
+        recognition_task: bool = False,
+        detection_task: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
+        )
+        # Task check
+        if recognition_task and detection_task:
+            raise ValueError(
+                " 'recognition' and 'detection task' cannot be set to True simultaneously. "
+                + " To get the whole dataset with boxes and labels leave both parameters to False "
+            )
+
+        # File existence check
+        if not os.path.exists(label_path) or not os.path.exists(img_folder):
+            raise FileNotFoundError(f"unable to find {label_path if not os.path.exists(label_path) else img_folder}")
+
+        tmp_root = img_folder
+        self.train = train
+        np_dtype = np.float32
+        self.data: list[tuple[str | Path | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
+
+        with open(label_path, "r") as file:
+            data = json.load(file)
+
+        # Filter images based on the set
+        img_items = [img for img in data["imgs"].items() if (img[1]["set"] == "train") == train]
+        box: list[float] | np.ndarray
+
+        for img_id, img_info in tqdm(img_items, desc="Preparing and Loading COCOTEXT", total=len(img_items)):
+            img_path = os.path.join(img_folder, img_info["file_name"])
+
+            # File existence check
+            if not os.path.exists(img_path):  # pragma: no cover
+                raise FileNotFoundError(f"Unable to locate {img_path}")
+
+            # Get annotations for the current image (only legible text)
+            annotations = [
+                ann
+                for ann in data["anns"].values()
+                if ann["image_id"] == int(img_id) and ann["legibility"] == "legible"
+            ]
+
+            # Some images have no annotations with readable text
+            if not annotations:  # pragma: no cover
+                continue
+
+            _targets = []
+
+            for annotation in annotations:
+                x, y, w, h = annotation["bbox"]
+                if use_polygons:
+                    # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+                    box = np.array(
+                        [
+                            [x, y],
+                            [x + w, y],
+                            [x + w, y + h],
+                            [x, y + h],
+                        ],
+                        dtype=np_dtype,
+                    )
+                else:
+                    # (xmin, ymin, xmax, ymax) coordinates
+                    box = [x, y, x + w, y + h]
+                _targets.append((annotation["utf8_string"], box))
+            text_targets, box_targets = zip(*_targets)
+
+            if recognition_task:
+                crops = crop_bboxes_from_image(
+                    img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
+                )
+                for crop, label in zip(crops, list(text_targets)):
+                    if label and " " not in label:
+                        self.data.append((crop, label))
+
+            elif detection_task:
+                self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0)))
+            else:
+                self.data.append((
+                    img_path,
+                    dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)),
+                ))
+
+        self.root = tmp_root
+
+    def extra_repr(self) -> str:
+        return f"train={self.train}"
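For context, the three target modes exposed by the new COCOTEXT dataset, following the constructor documented in the diff above (the paths are placeholders):

from doctr.datasets import COCOTEXT

# Default: each sample is (image, {"boxes": ..., "labels": ...})
full_set = COCOTEXT(img_folder="/path/to/coco_text/train2014/", label_path="/path/to/coco_text/cocotext.v2.json")

# Recognition mode: samples become (word crop, text label), single-word labels only
reco_set = COCOTEXT(
    img_folder="/path/to/coco_text/train2014/",
    label_path="/path/to/coco_text/cocotext.v2.json",
    recognition_task=True,
)

# Detection mode: samples become (image, boxes)
det_set = COCOTEXT(
    img_folder="/path/to/coco_text/train2014/",
    label_path="/path/to/coco_text/cocotext.v2.json",
    detection_task=True,
)

img, target = full_set[0]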
doctr/datasets/cord.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Any
+from typing import Any
 
 import numpy as np
 from tqdm import tqdm
@@ -29,7 +29,6 @@ class CORD(VisionDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
@@ -72,12 +71,14 @@ class CORD(VisionDataset):
                 + "To get the whole dataset with boxes and labels leave both parameters to False."
             )
 
-        #
+        # list images
         tmp_root = os.path.join(self.root, "image")
-        self.data:
+        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
         self.train = train
         np_dtype = np.float32
-        for img_path in tqdm(
+        for img_path in tqdm(
+            iterable=os.listdir(tmp_root), desc="Preparing and Loading CORD", total=len(os.listdir(tmp_root))
+        ):
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_path)):
                 raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
@@ -91,7 +92,7 @@ class CORD(VisionDataset):
                         if len(word["text"]) > 0:
                             x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"]
                             y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"]
-                            box:
+                            box: list[float] | np.ndarray
                            if use_polygons:
                                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                                box = np.array(
@@ -115,7 +116,8 @@ class CORD(VisionDataset):
                     img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
                 )
                 for crop, label in zip(crops, list(text_targets)):
-
+                    if " " not in label:
+                        self.data.append((crop, label))
             elif detection_task:
                 self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0)))
             else:

doctr/datasets/datasets/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/datasets/datasets/base.py
CHANGED
@@ -1,12 +1,13 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import os
 import shutil
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any
+from typing import Any
 
 import numpy as np
 
@@ -19,15 +20,15 @@ __all__ = ["_AbstractDataset", "_VisionDataset"]
 
 
 class _AbstractDataset:
-    data:
-    _pre_transforms:
+    data: list[Any] = []
+    _pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None
 
     def __init__(
         self,
-        root:
-        img_transforms:
-        sample_transforms:
-        pre_transforms:
+        root: str | Path,
+        img_transforms: Callable[[Any], Any] | None = None,
+        sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
+        pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
     ) -> None:
         if not Path(root).is_dir():
             raise ValueError(f"expected a path to a reachable folder: {root}")
@@ -41,10 +42,10 @@ class _AbstractDataset:
     def __len__(self) -> int:
         return len(self.data)
 
-    def _read_sample(self, index: int) ->
+    def _read_sample(self, index: int) -> tuple[Any, Any]:
         raise NotImplementedError
 
-    def __getitem__(self, index: int) ->
+    def __getitem__(self, index: int) -> tuple[Any, Any]:
         # Read image
         img, target = self._read_sample(index)
         # Pre-transforms (format conversion at run-time etc.)
@@ -82,7 +83,6 @@ class _VisionDataset(_AbstractDataset):
     """Implements an abstract dataset
 
     Args:
-    ----
         url: URL of the dataset
         file_name: name of the file once downloaded
         file_hash: expected SHA256 of the file
@@ -96,13 +96,13 @@ class _VisionDataset(_AbstractDataset):
     def __init__(
         self,
         url: str,
-        file_name:
-        file_hash:
+        file_name: str | None = None,
+        file_hash: str | None = None,
         extract_archive: bool = False,
         download: bool = False,
         overwrite: bool = False,
-        cache_dir:
-        cache_subdir:
+        cache_dir: str | None = None,
+        cache_subdir: str | None = None,
         **kwargs: Any,
     ) -> None:
         cache_dir = (
@@ -115,7 +115,7 @@ class _VisionDataset(_AbstractDataset):
 
         file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
         # Download the file if not present
-        archive_path:
+        archive_path: str | Path = os.path.join(cache_dir, cache_subdir, file_name)
 
         if not os.path.exists(archive_path) and not download:
             raise ValueError("the dataset needs to be downloaded first with download=True")
doctr/datasets/datasets/pytorch.py
CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import os
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import numpy as np
 import torch
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
 class AbstractDataset(_AbstractDataset):
     """Abstract class for all datasets"""
 
-    def _read_sample(self, index: int) ->
+    def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]:
         img_name, target = self.data[index]
 
         # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
             assert "labels" in target, "Target should contain 'labels' key"
         elif isinstance(target, tuple):
             assert len(target) == 2
-            assert isinstance(target[0], str) or isinstance(
-
-            )
+            assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+                "first element of the tuple should be a string or a numpy array"
+            )
             assert isinstance(target[1], list), "second element of the tuple should be a list"
         else:
-            assert isinstance(target, str) or isinstance(
-
-            )
+            assert isinstance(target, str) or isinstance(target, np.ndarray), (
+                "Target should be a string or a numpy array"
+            )
 
         # Read image
         img = (
@@ -48,11 +48,11 @@ class AbstractDataset(_AbstractDataset):
         return img, deepcopy(target)
 
     @staticmethod
-    def collate_fn(samples:
+    def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]:
         images, targets = zip(*samples)
-        images = torch.stack(images, dim=0)
+        images = torch.stack(images, dim=0)
 
-        return images, list(targets)
+        return images, list(targets)
 
 
 class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
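As a usage note, the typed `collate_fn` above is what doctr datasets hand to a PyTorch DataLoader; a sketch under the assumption that images are resized to a common shape first (dataset choice and sizes are illustrative):

from torch.utils.data import DataLoader

from doctr.datasets import CORD
from doctr.transforms import Resize

# Resize so that torch.stack inside collate_fn receives tensors of identical shape
train_set = CORD(train=True, download=True, img_transforms=Resize((512, 512)))
loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=train_set.collate_fn)

images, targets = next(iter(loader))  # images: (4, 3, 512, 512) tensor, targets: list of dicts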
doctr/datasets/datasets/tensorflow.py
CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import os
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
 class AbstractDataset(_AbstractDataset):
     """Abstract class for all datasets"""
 
-    def _read_sample(self, index: int) ->
+    def _read_sample(self, index: int) -> tuple[tf.Tensor, Any]:
         img_name, target = self.data[index]
 
         # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
             assert "labels" in target, "Target should contain 'labels' key"
         elif isinstance(target, tuple):
             assert len(target) == 2
-            assert isinstance(target[0], str) or isinstance(
-
-            )
+            assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+                "first element of the tuple should be a string or a numpy array"
+            )
             assert isinstance(target[1], list), "second element of the tuple should be a list"
         else:
-            assert isinstance(target, str) or isinstance(
-
-            )
+            assert isinstance(target, str) or isinstance(target, np.ndarray), (
+                "Target should be a string or a numpy array"
+            )
 
         # Read image
         img = (
@@ -48,7 +48,7 @@ class AbstractDataset(_AbstractDataset):
         return img, deepcopy(target)
 
     @staticmethod
-    def collate_fn(samples:
+    def collate_fn(samples: list[tuple[tf.Tensor, Any]]) -> tuple[tf.Tensor, list[Any]]:
         images, targets = zip(*samples)
         images = tf.stack(images, axis=0)
 
doctr/datasets/detection.py
CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import json
 import os
-from typing import Any
+from typing import Any
 
 import numpy as np
 
@@ -26,7 +26,6 @@ class DetectionDataset(AbstractDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         img_folder: folder with all the images of the dataset
         label_path: path to the annotations of each image
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
@@ -47,13 +46,13 @@ class DetectionDataset(AbstractDataset):
         )
 
         # File existence check
-        self._class_names:
+        self._class_names: list = []
         if not os.path.exists(label_path):
             raise FileNotFoundError(f"unable to locate {label_path}")
         with open(label_path, "rb") as f:
             labels = json.load(f)
 
-        self.data:
+        self.data: list[tuple[str, tuple[np.ndarray, list[str]]]] = []
         np_dtype = np.float32
         for img_name, label in labels.items():
             # File existence check
@@ -65,18 +64,16 @@ class DetectionDataset(AbstractDataset):
             self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))
 
     def format_polygons(
-        self, polygons:
-    ) ->
+        self, polygons: list | dict, use_polygons: bool, np_dtype: type
+    ) -> tuple[np.ndarray, list[str]]:
         """Format polygons into an array
 
         Args:
-        ----
             polygons: the bounding boxes
             use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
             np_dtype: dtype of array
 
         Returns:
-        -------
             geoms: bounding boxes as np array
             polygons_classes: list of classes for each bounding box
         """
doctr/datasets/doc_artefacts.py
CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import json
 import os
-from typing import Any
+from typing import Any
 
 import numpy as np
 
@@ -26,7 +26,6 @@ class DocArtefacts(VisionDataset):
     >>> img, target = train_set[0]
 
     Args:
-    ----
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         **kwargs: keyword arguments from `VisionDataset`.
@@ -51,7 +50,7 @@ class DocArtefacts(VisionDataset):
         tmp_root = os.path.join(self.root, "images")
         with open(os.path.join(self.root, "labels.json"), "rb") as f:
             labels = json.load(f)
-        self.data:
+        self.data: list[tuple[str, dict[str, Any]]] = []
         img_list = os.listdir(tmp_root)
         if len(labels) != len(img_list):
             raise AssertionError("the number of images and labels do not match")