python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff compares the content of two publicly available package versions as published to their public registry. It is provided for informational purposes only.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/contrib/__init__.py CHANGED
@@ -0,0 +1 @@
+ from .artefacts import ArtefactDetector
doctr/contrib/artefacts.py CHANGED
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any

  import cv2
  import numpy as np
@@ -14,7 +14,7 @@ from .base import _BasePredictor

  __all__ = ["ArtefactDetector"]

- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
      "yolov8_artefact": {
          "input_shape": (3, 1024, 1024),
          "labels": ["bar_code", "qr_code", "logo", "photo"],
@@ -34,7 +34,6 @@ class ArtefactDetector(_BasePredictor):
      >>> results = detector(doc)

      Args:
-     ----
          arch: the architecture to use
          batch_size: the batch size to use
          model_path: the path to the model to use
@@ -50,9 +49,9 @@ class ArtefactDetector(_BasePredictor):
          self,
          arch: str = "yolov8_artefact",
          batch_size: int = 2,
-         model_path: Optional[str] = None,
-         labels: Optional[List[str]] = None,
-         input_shape: Optional[Tuple[int, int, int]] = None,
+         model_path: str | None = None,
+         labels: list[str] | None = None,
+         input_shape: tuple[int, int, int] | None = None,
          conf_threshold: float = 0.5,
          iou_threshold: float = 0.5,
          **kwargs: Any,
@@ -66,7 +65,7 @@ class ArtefactDetector(_BasePredictor):
      def preprocess(self, img: np.ndarray) -> np.ndarray:
          return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)

-     def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
+     def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> list[list[dict[str, Any]]]:
          results = []

          for batch in zip(output, input_images):
@@ -109,7 +108,6 @@ class ArtefactDetector(_BasePredictor):
          Display the results

          Args:
-         ----
              **kwargs: additional keyword arguments to be passed to `plt.show`
          """
          requires_package("matplotlib", "`.show()` requires matplotlib installed")
doctr/contrib/base.py CHANGED
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, List, Optional
+ from typing import Any

  import numpy as np

@@ -16,32 +16,29 @@ class _BasePredictor:
      Base class for all predictors

      Args:
-     ----
          batch_size: the batch size to use
          url: the url to use to download a model if needed
          model_path: the path to the model to use
          **kwargs: additional arguments to be passed to `download_from_url`
      """

-     def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None:
+     def __init__(self, batch_size: int, url: str | None = None, model_path: str | None = None, **kwargs) -> None:
          self.batch_size = batch_size
          self.session = self._init_model(url, model_path, **kwargs)

-         self._inputs: List[np.ndarray] = []
-         self._results: List[Any] = []
+         self._inputs: list[np.ndarray] = []
+         self._results: list[Any] = []

-     def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
+     def _init_model(self, url: str | None = None, model_path: str | None = None, **kwargs: Any) -> Any:
          """
          Download the model from the given url if needed

          Args:
-         ----
              url: the url to use
              model_path: the path to the model to use
              **kwargs: additional arguments to be passed to `download_from_url`

          Returns:
-         -------
              Any: the ONNX loaded model
          """
          requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
@@ -57,40 +54,34 @@ class _BasePredictor:
          Preprocess the input image

          Args:
-         ----
              img: the input image to preprocess

          Returns:
-         -------
              np.ndarray: the preprocessed image
          """
          raise NotImplementedError

-     def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
+     def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> Any:
          """
          Postprocess the model output

          Args:
-         ----
              output: the model output to postprocess
              input_images: the input images used to generate the output

          Returns:
-         -------
              Any: the postprocessed output
          """
          raise NotImplementedError

-     def __call__(self, inputs: List[np.ndarray]) -> Any:
+     def __call__(self, inputs: list[np.ndarray]) -> Any:
          """
          Call the model on the given inputs

          Args:
-         ----
              inputs: the inputs to use

          Returns:
-         -------
              Any: the postprocessed output
          """
          self._inputs = inputs
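`_BasePredictor` is the whole contrib contract: `__call__` batches the inputs, runs the ONNX session, and funnels the raw outputs through the two hooks above. A minimal sketch of a custom predictor under that contract; `MeanPredictor` and its internals are hypothetical, only the hook signatures come from the class in this diff:

    import numpy as np

    from doctr.contrib.base import _BasePredictor


    class MeanPredictor(_BasePredictor):  # hypothetical subclass, for illustration only
        def preprocess(self, img: np.ndarray) -> np.ndarray:
            # HWC uint8 -> CHW float in [0, 1], mirroring ArtefactDetector.preprocess
            return np.transpose(img, (2, 0, 1)).astype(np.float32) / 255.0

        def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> list:
            # reduce each raw session output to whatever result type the predictor exposes
            return [float(out.mean()) for out in output]

    # predictor = MeanPredictor(batch_size=2, model_path="path/to/model.onnx")
    # results = predictor(list_of_numpy_images)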
doctr/datasets/__init__.py CHANGED
@@ -1,6 +1,7 @@
  from doctr.file_utils import is_tf_available

  from .generator import *
+ from .coco_text import *
  from .cord import *
  from .detection import *
  from .doc_artefacts import *
doctr/datasets/coco_text.py ADDED
@@ -0,0 +1,139 @@
+ # Copyright (C) 2021-2025, Mindee.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ import json
+ import os
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ from tqdm import tqdm
+
+ from .datasets import AbstractDataset
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+ __all__ = ["COCOTEXT"]
+
+
+ class COCOTEXT(AbstractDataset):
+     """
+     COCO-Text dataset from `"COCO-Text: Dataset and Benchmark for Text Detection and Recognition in Natural Images"
+     <https://arxiv.org/pdf/1601.07140v2>`_ |
+     `"homepage" <https://bgshih.github.io/cocotext/>`_.
+
+     >>> # NOTE: You need to download the dataset first.
+     >>> from doctr.datasets import COCOTEXT
+     >>> train_set = COCOTEXT(train=True, img_folder="/path/to/coco_text/train2014/",
+     >>>                      label_path="/path/to/coco_text/cocotext.v2.json")
+     >>> img, target = train_set[0]
+     >>> test_set = COCOTEXT(train=False, img_folder="/path/to/coco_text/train2014/",
+     >>>                     label_path = "/path/to/coco_text/cocotext.v2.json")
+     >>> img, target = test_set[0]
+
+     Args:
+         img_folder: folder with all the images of the dataset
+         label_path: path to the annotations file of the dataset
+         train: whether the subset should be the training one
+         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+         recognition_task: whether the dataset should be used for recognition task
+         detection_task: whether the dataset should be used for detection task
+         **kwargs: keyword arguments from `AbstractDataset`.
+     """
+
+     def __init__(
+         self,
+         img_folder: str,
+         label_path: str,
+         train: bool = True,
+         use_polygons: bool = False,
+         recognition_task: bool = False,
+         detection_task: bool = False,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(
+             img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
+         )
+         # Task check
+         if recognition_task and detection_task:
+             raise ValueError(
+                 " 'recognition' and 'detection task' cannot be set to True simultaneously. "
+                 + " To get the whole dataset with boxes and labels leave both parameters to False "
+             )
+
+         # File existence check
+         if not os.path.exists(label_path) or not os.path.exists(img_folder):
+             raise FileNotFoundError(f"unable to find {label_path if not os.path.exists(label_path) else img_folder}")
+
+         tmp_root = img_folder
+         self.train = train
+         np_dtype = np.float32
+         self.data: list[tuple[str | Path | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
+
+         with open(label_path, "r") as file:
+             data = json.load(file)
+
+         # Filter images based on the set
+         img_items = [img for img in data["imgs"].items() if (img[1]["set"] == "train") == train]
+         box: list[float] | np.ndarray
+
+         for img_id, img_info in tqdm(img_items, desc="Preparing and Loading COCOTEXT", total=len(img_items)):
+             img_path = os.path.join(img_folder, img_info["file_name"])
+
+             # File existence check
+             if not os.path.exists(img_path):  # pragma: no cover
+                 raise FileNotFoundError(f"Unable to locate {img_path}")
+
+             # Get annotations for the current image (only legible text)
+             annotations = [
+                 ann
+                 for ann in data["anns"].values()
+                 if ann["image_id"] == int(img_id) and ann["legibility"] == "legible"
+             ]
+
+             # Some images have no annotations with readable text
+             if not annotations:  # pragma: no cover
+                 continue
+
+             _targets = []
+
+             for annotation in annotations:
+                 x, y, w, h = annotation["bbox"]
+                 if use_polygons:
+                     # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+                     box = np.array(
+                         [
+                             [x, y],
+                             [x + w, y],
+                             [x + w, y + h],
+                             [x, y + h],
+                         ],
+                         dtype=np_dtype,
+                     )
+                 else:
+                     # (xmin, ymin, xmax, ymax) coordinates
+                     box = [x, y, x + w, y + h]
+                 _targets.append((annotation["utf8_string"], box))
+             text_targets, box_targets = zip(*_targets)
+
+             if recognition_task:
+                 crops = crop_bboxes_from_image(
+                     img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
+                 )
+                 for crop, label in zip(crops, list(text_targets)):
+                     if label and " " not in label:
+                         self.data.append((crop, label))
+
+             elif detection_task:
+                 self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0)))
+             else:
+                 self.data.append((
+                     img_path,
+                     dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)),
+                 ))
+
+         self.root = tmp_root
+
+     def extra_repr(self) -> str:
+         return f"train={self.train}"
doctr/datasets/cord.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
  import json
  import os
  from pathlib import Path
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import numpy as np
  from tqdm import tqdm
@@ -29,7 +29,6 @@ class CORD(VisionDataset):
      >>> img, target = train_set[0]

      Args:
-     ----
          train: whether the subset should be the training one
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          recognition_task: whether the dataset should be used for recognition task
@@ -72,12 +71,14 @@ class CORD(VisionDataset):
                  + "To get the whole dataset with boxes and labels leave both parameters to False."
              )

-         # List images
+         # list images
          tmp_root = os.path.join(self.root, "image")
-         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
+         self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
          self.train = train
          np_dtype = np.float32
-         for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))):
+         for img_path in tqdm(
+             iterable=os.listdir(tmp_root), desc="Preparing and Loading CORD", total=len(os.listdir(tmp_root))
+         ):
              # File existence check
              if not os.path.exists(os.path.join(tmp_root, img_path)):
                  raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
@@ -91,7 +92,7 @@ class CORD(VisionDataset):
                      if len(word["text"]) > 0:
                          x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"]
                          y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"]
-                         box: Union[List[float], np.ndarray]
+                         box: list[float] | np.ndarray
                          if use_polygons:
                              # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                              box = np.array(
@@ -115,7 +116,8 @@ class CORD(VisionDataset):
                      img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
                  )
                  for crop, label in zip(crops, list(text_targets)):
-                     self.data.append((crop, label))
+                     if " " not in label:
+                         self.data.append((crop, label))
              elif detection_task:
                  self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0)))
              else:
doctr/datasets/datasets/__init__.py CHANGED
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
-     from .tensorflow import *
- elif is_torch_available():
-     from .pytorch import *  # type: ignore[assignment]
+ if is_torch_available():
+     from .pytorch import *
+ elif is_tf_available():
+     from .tensorflow import *  # type: ignore[assignment]
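This small hunk hides a behavioral change: 0.10.0 resolved the dataset backend TensorFlow-first, so an environment with both frameworks installed got the TensorFlow implementations; 0.12.0 checks `is_torch_available()` first, making PyTorch the preferred backend when both are present. The other backend-switch `__init__.py` files in the list above carry similar small ± line counts, consistent with the same reordering.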
doctr/datasets/datasets/base.py CHANGED
@@ -1,12 +1,13 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
  import shutil
+ from collections.abc import Callable
  from pathlib import Path
- from typing import Any, Callable, List, Optional, Tuple, Union
+ from typing import Any

  import numpy as np

@@ -19,15 +20,15 @@ __all__ = ["_AbstractDataset", "_VisionDataset"]


  class _AbstractDataset:
-     data: List[Any] = []
-     _pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None
+     data: list[Any] = []
+     _pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None

      def __init__(
          self,
-         root: Union[str, Path],
-         img_transforms: Optional[Callable[[Any], Any]] = None,
-         sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
-         pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+         root: str | Path,
+         img_transforms: Callable[[Any], Any] | None = None,
+         sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
+         pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
      ) -> None:
          if not Path(root).is_dir():
              raise ValueError(f"expected a path to a reachable folder: {root}")
@@ -41,10 +42,10 @@ class _AbstractDataset:
      def __len__(self) -> int:
          return len(self.data)

-     def _read_sample(self, index: int) -> Tuple[Any, Any]:
+     def _read_sample(self, index: int) -> tuple[Any, Any]:
          raise NotImplementedError

-     def __getitem__(self, index: int) -> Tuple[Any, Any]:
+     def __getitem__(self, index: int) -> tuple[Any, Any]:
          # Read image
          img, target = self._read_sample(index)
          # Pre-transforms (format conversion at run-time etc.)
@@ -82,7 +83,6 @@ class _VisionDataset(_AbstractDataset):
      """Implements an abstract dataset

      Args:
-     ----
          url: URL of the dataset
          file_name: name of the file once downloaded
          file_hash: expected SHA256 of the file
@@ -96,13 +96,13 @@ class _VisionDataset(_AbstractDataset):
      def __init__(
          self,
          url: str,
-         file_name: Optional[str] = None,
-         file_hash: Optional[str] = None,
+         file_name: str | None = None,
+         file_hash: str | None = None,
          extract_archive: bool = False,
          download: bool = False,
          overwrite: bool = False,
-         cache_dir: Optional[str] = None,
-         cache_subdir: Optional[str] = None,
+         cache_dir: str | None = None,
+         cache_subdir: str | None = None,
          **kwargs: Any,
      ) -> None:
          cache_dir = (
@@ -115,7 +115,7 @@ class _VisionDataset(_AbstractDataset):

          file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
          # Download the file if not present
-         archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name)
+         archive_path: str | Path = os.path.join(cache_dir, cache_subdir, file_name)

          if not os.path.exists(archive_path) and not download:
              raise ValueError("the dataset needs to be downloaded first with download=True")
doctr/datasets/datasets/pytorch.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
  from copy import deepcopy
- from typing import Any, List, Tuple
+ from typing import Any

  import numpy as np
  import torch
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
  class AbstractDataset(_AbstractDataset):
      """Abstract class for all datasets"""

-     def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
+     def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]:
          img_name, target = self.data[index]

          # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
              assert "labels" in target, "Target should contain 'labels' key"
          elif isinstance(target, tuple):
              assert len(target) == 2
-             assert isinstance(target[0], str) or isinstance(
-                 target[0], np.ndarray
-             ), "first element of the tuple should be a string or a numpy array"
+             assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+                 "first element of the tuple should be a string or a numpy array"
+             )
              assert isinstance(target[1], list), "second element of the tuple should be a list"
          else:
-             assert isinstance(target, str) or isinstance(
-                 target, np.ndarray
-             ), "Target should be a string or a numpy array"
+             assert isinstance(target, str) or isinstance(target, np.ndarray), (
+                 "Target should be a string or a numpy array"
+             )

          # Read image
          img = (
@@ -48,11 +48,11 @@ class AbstractDataset(_AbstractDataset):
          return img, deepcopy(target)

      @staticmethod
-     def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
+     def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]:
          images, targets = zip(*samples)
-         images = torch.stack(images, dim=0)  # type: ignore[assignment]
+         images = torch.stack(images, dim=0)

-         return images, list(targets)  # type: ignore[return-value]
+         return images, list(targets)


  class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
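`collate_fn` exists because OCR targets are heterogeneous (strings, dicts, arrays) and cannot be stacked like tensors: images are stacked into one batch tensor while targets stay a plain Python list. A sketch of wiring it into a `torch.utils.data.DataLoader`, assuming the PyTorch backend; the dataset choice and sizes are illustrative, and the `Resize` matters because `torch.stack` needs a uniform image shape:

    from torch.utils.data import DataLoader

    from doctr.datasets import CORD
    from doctr.transforms import Resize

    train_set = CORD(train=True, download=True, img_transforms=Resize((512, 512)))
    loader = DataLoader(train_set, batch_size=8, collate_fn=train_set.collate_fn)

    images, targets = next(iter(loader))
    # images: torch.Tensor of shape (8, 3, 512, 512); targets: list of 8 per-image dicts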
doctr/datasets/datasets/tensorflow.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
  from copy import deepcopy
- from typing import Any, List, Tuple
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
  class AbstractDataset(_AbstractDataset):
      """Abstract class for all datasets"""

-     def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
+     def _read_sample(self, index: int) -> tuple[tf.Tensor, Any]:
          img_name, target = self.data[index]

          # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
              assert "labels" in target, "Target should contain 'labels' key"
          elif isinstance(target, tuple):
              assert len(target) == 2
-             assert isinstance(target[0], str) or isinstance(
-                 target[0], np.ndarray
-             ), "first element of the tuple should be a string or a numpy array"
+             assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+                 "first element of the tuple should be a string or a numpy array"
+             )
              assert isinstance(target[1], list), "second element of the tuple should be a list"
          else:
-             assert isinstance(target, str) or isinstance(
-                 target, np.ndarray
-             ), "Target should be a string or a numpy array"
+             assert isinstance(target, str) or isinstance(target, np.ndarray), (
+                 "Target should be a string or a numpy array"
+             )

          # Read image
          img = (
@@ -48,7 +48,7 @@ class AbstractDataset(_AbstractDataset):
          return img, deepcopy(target)

      @staticmethod
-     def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]:
+     def collate_fn(samples: list[tuple[tf.Tensor, Any]]) -> tuple[tf.Tensor, list[Any]]:
          images, targets = zip(*samples)
          images = tf.stack(images, axis=0)

doctr/datasets/detection.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import json
  import os
- from typing import Any, Dict, List, Tuple, Type, Union
+ from typing import Any

  import numpy as np

@@ -26,7 +26,6 @@ class DetectionDataset(AbstractDataset):
      >>> img, target = train_set[0]

      Args:
-     ----
          img_folder: folder with all the images of the dataset
          label_path: path to the annotations of each image
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
@@ -47,13 +46,13 @@ class DetectionDataset(AbstractDataset):
          )

          # File existence check
-         self._class_names: List = []
+         self._class_names: list = []
          if not os.path.exists(label_path):
              raise FileNotFoundError(f"unable to locate {label_path}")
          with open(label_path, "rb") as f:
              labels = json.load(f)

-         self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
+         self.data: list[tuple[str, tuple[np.ndarray, list[str]]]] = []
          np_dtype = np.float32
          for img_name, label in labels.items():
              # File existence check
@@ -65,18 +64,16 @@ class DetectionDataset(AbstractDataset):
              self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))

      def format_polygons(
-         self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
-     ) -> Tuple[np.ndarray, List[str]]:
+         self, polygons: list | dict, use_polygons: bool, np_dtype: type
+     ) -> tuple[np.ndarray, list[str]]:
          """Format polygons into an array

          Args:
-         ----
              polygons: the bounding boxes
              use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
              np_dtype: dtype of array

          Returns:
-         -------
              geoms: bounding boxes as np array
              polygons_classes: list of classes for each bounding box
          """
doctr/datasets/doc_artefacts.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import json
  import os
- from typing import Any, Dict, List, Tuple
+ from typing import Any

  import numpy as np

@@ -26,7 +26,6 @@ class DocArtefacts(VisionDataset):
      >>> img, target = train_set[0]

      Args:
-     ----
          train: whether the subset should be the training one
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          **kwargs: keyword arguments from `VisionDataset`.
@@ -51,7 +50,7 @@ class DocArtefacts(VisionDataset):
          tmp_root = os.path.join(self.root, "images")
          with open(os.path.join(self.root, "labels.json"), "rb") as f:
              labels = json.load(f)
-         self.data: List[Tuple[str, Dict[str, Any]]] = []
+         self.data: list[tuple[str, dict[str, Any]]] = []
          img_list = os.listdir(tmp_root)
          if len(labels) != len(img_list):
              raise AssertionError("the number of images and labels do not match")