python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/contrib/__init__.py CHANGED
@@ -0,0 +1 @@
+ from .artefacts import ArtefactDetector
doctr/contrib/artefacts.py CHANGED
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any

  import cv2
  import numpy as np
@@ -14,7 +14,7 @@ from .base import _BasePredictor

  __all__ = ["ArtefactDetector"]

- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
      "yolov8_artefact": {
          "input_shape": (3, 1024, 1024),
          "labels": ["bar_code", "qr_code", "logo", "photo"],
@@ -34,7 +34,6 @@ class ArtefactDetector(_BasePredictor):
      >>> results = detector(doc)

      Args:
-     ----
          arch: the architecture to use
          batch_size: the batch size to use
          model_path: the path to the model to use
@@ -50,9 +49,9 @@ class ArtefactDetector(_BasePredictor):
          self,
          arch: str = "yolov8_artefact",
          batch_size: int = 2,
-         model_path: Optional[str] = None,
-         labels: Optional[List[str]] = None,
-         input_shape: Optional[Tuple[int, int, int]] = None,
+         model_path: str | None = None,
+         labels: list[str] | None = None,
+         input_shape: tuple[int, int, int] | None = None,
          conf_threshold: float = 0.5,
          iou_threshold: float = 0.5,
          **kwargs: Any,
@@ -66,7 +65,7 @@ class ArtefactDetector(_BasePredictor):
      def preprocess(self, img: np.ndarray) -> np.ndarray:
          return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)

-     def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
+     def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> list[list[dict[str, Any]]]:
          results = []

          for batch in zip(output, input_images):
@@ -109,7 +108,6 @@ class ArtefactDetector(_BasePredictor):
          Display the results

          Args:
-         ----
              **kwargs: additional keyword arguments to be passed to `plt.show`
          """
          requires_package("matplotlib", "`.show()` requires matplotlib installed")
doctr/contrib/base.py CHANGED
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, List, Optional
+ from typing import Any

  import numpy as np

@@ -16,32 +16,29 @@ class _BasePredictor:
      Base class for all predictors

      Args:
-     ----
          batch_size: the batch size to use
          url: the url to use to download a model if needed
          model_path: the path to the model to use
          **kwargs: additional arguments to be passed to `download_from_url`
      """

-     def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None:
+     def __init__(self, batch_size: int, url: str | None = None, model_path: str | None = None, **kwargs) -> None:
          self.batch_size = batch_size
          self.session = self._init_model(url, model_path, **kwargs)

-         self._inputs: List[np.ndarray] = []
-         self._results: List[Any] = []
+         self._inputs: list[np.ndarray] = []
+         self._results: list[Any] = []

-     def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
+     def _init_model(self, url: str | None = None, model_path: str | None = None, **kwargs: Any) -> Any:
          """
          Download the model from the given url if needed

          Args:
-         ----
              url: the url to use
              model_path: the path to the model to use
              **kwargs: additional arguments to be passed to `download_from_url`

          Returns:
-         -------
              Any: the ONNX loaded model
          """
          requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
@@ -57,40 +54,34 @@ class _BasePredictor:
          Preprocess the input image

          Args:
-         ----
              img: the input image to preprocess

          Returns:
-         -------
              np.ndarray: the preprocessed image
          """
          raise NotImplementedError

-     def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
+     def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> Any:
          """
          Postprocess the model output

          Args:
-         ----
              output: the model output to postprocess
              input_images: the input images used to generate the output

          Returns:
-         -------
              Any: the postprocessed output
          """
          raise NotImplementedError

-     def __call__(self, inputs: List[np.ndarray]) -> Any:
+     def __call__(self, inputs: list[np.ndarray]) -> Any:
          """
          Call the model on the given inputs

          Args:
-         ----
              inputs: the inputs to use

          Returns:
-         -------
              Any: the postprocessed output
          """
          self._inputs = inputs
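The typing changes in this file repeat across nearly every module in the diff: `List`, `Dict`, `Tuple`, `Optional`, and `Union` imports from `typing` give way to PEP 585 built-in generics and PEP 604 union syntax, and `Callable` moves to `collections.abc`. An illustrative sketch of the new spellings (not code from the package):

    from collections.abc import Callable

    # 0.9.0 spelling:  Optional[List[str]], Tuple[int, str], Union[int, None]
    # 0.11.0 spelling: the same types, written with built-ins and `|`
    def run(steps: list[str], hook: Callable[[str], None] | None = None) -> tuple[int, str] | None:
        """Run each step, returning (index, step) of the first failure, or None."""
        for i, step in enumerate(steps):
            if hook is not None:
                hook(step)
            if step == "fail":
                return i, step
        return None

Both spellings describe identical types; the newer ones need Python 3.9+ at runtime (3.10+ for `X | None` in annotations evaluated at runtime).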
doctr/datasets/cord.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
  import json
  import os
  from pathlib import Path
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import numpy as np
  from tqdm import tqdm
@@ -29,10 +29,10 @@ class CORD(VisionDataset):
      >>> img, target = train_set[0]

      Args:
-     ----
          train: whether the subset should be the training one
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          recognition_task: whether the dataset should be used for recognition task
+         detection_task: whether the dataset should be used for detection task
          **kwargs: keyword arguments from `VisionDataset`.
      """

@@ -53,6 +53,7 @@ class CORD(VisionDataset):
          train: bool = True,
          use_polygons: bool = False,
          recognition_task: bool = False,
+         detection_task: bool = False,
          **kwargs: Any,
      ) -> None:
          url, sha256, name = self.TRAIN if train else self.TEST
@@ -64,13 +65,20 @@ class CORD(VisionDataset):
              pre_transforms=convert_target_to_relative if not recognition_task else None,
              **kwargs,
          )
+         if recognition_task and detection_task:
+             raise ValueError(
+                 "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                 + "To get the whole dataset with boxes and labels leave both parameters to False."
+             )

-         # List images
+         # list images
          tmp_root = os.path.join(self.root, "image")
-         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+         self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
          self.train = train
          np_dtype = np.float32
-         for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))):
+         for img_path in tqdm(
+             iterable=os.listdir(tmp_root), desc="Preparing and Loading CORD", total=len(os.listdir(tmp_root))
+         ):
              # File existence check
              if not os.path.exists(os.path.join(tmp_root, img_path)):
                  raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
@@ -84,7 +92,7 @@ class CORD(VisionDataset):
                  if len(word["text"]) > 0:
                      x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"]
                      y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"]
-                     box: Union[List[float], np.ndarray]
+                     box: list[float] | np.ndarray
                      if use_polygons:
                          # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                          box = np.array(
@@ -109,6 +117,8 @@ class CORD(VisionDataset):
              )
              for crop, label in zip(crops, list(text_targets)):
                  self.data.append((crop, label))
+         elif detection_task:
+             self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0)))
          else:
              self.data.append((
                  img_path,
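The new `detection_task` flag mirrors the existing `recognition_task` one: it swaps the full `(boxes, labels)` targets for box arrays only, and the two flags are now mutually exclusive. A usage sketch (FUNSD below gains the same flag):

    from doctr.datasets import CORD

    # Boxes-only targets, suitable for training a text-detection model
    train_set = CORD(train=True, download=True, detection_task=True)
    img, boxes = train_set[0]  # boxes: numpy array of box coordinates

    # Raises ValueError, per the check added above:
    # CORD(train=True, download=True, recognition_task=True, detection_task=True)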
doctr/datasets/datasets/__init__.py CHANGED
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
-     from .tensorflow import *
- elif is_torch_available():
-     from .pytorch import *  # type: ignore[assignment]
+ if is_torch_available():
+     from .pytorch import *
+ elif is_tf_available():
+     from .tensorflow import *  # type: ignore[assignment]
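This swap, repeated verbatim in `generator/__init__.py` and the other backend-gated packages further down, reverses the tie-break when both frameworks are installed: 0.9.0 preferred the TensorFlow implementations, 0.11.0 prefers PyTorch. A quick way to check which backend was selected at import time (a sketch, assuming at least one backend is installed):

    import doctr.datasets.datasets as ds

    print(ds.AbstractDataset.__module__)
    # 'doctr.datasets.datasets.pytorch' when torch is installed (0.11.0 order)
    # 'doctr.datasets.datasets.tensorflow' only when torch is absent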
doctr/datasets/datasets/base.py CHANGED
@@ -1,12 +1,13 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
  import shutil
+ from collections.abc import Callable
  from pathlib import Path
- from typing import Any, Callable, List, Optional, Tuple, Union
+ from typing import Any

  import numpy as np

@@ -19,15 +20,15 @@ __all__ = ["_AbstractDataset", "_VisionDataset"]


  class _AbstractDataset:
-     data: List[Any] = []
-     _pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None
+     data: list[Any] = []
+     _pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None

      def __init__(
          self,
-         root: Union[str, Path],
-         img_transforms: Optional[Callable[[Any], Any]] = None,
-         sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
-         pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+         root: str | Path,
+         img_transforms: Callable[[Any], Any] | None = None,
+         sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
+         pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
      ) -> None:
          if not Path(root).is_dir():
              raise ValueError(f"expected a path to a reachable folder: {root}")
@@ -41,10 +42,10 @@ class _AbstractDataset:
      def __len__(self) -> int:
          return len(self.data)

-     def _read_sample(self, index: int) -> Tuple[Any, Any]:
+     def _read_sample(self, index: int) -> tuple[Any, Any]:
          raise NotImplementedError

-     def __getitem__(self, index: int) -> Tuple[Any, Any]:
+     def __getitem__(self, index: int) -> tuple[Any, Any]:
          # Read image
          img, target = self._read_sample(index)
          # Pre-transforms (format conversion at run-time etc.)
@@ -82,7 +83,6 @@ class _VisionDataset(_AbstractDataset):
      """Implements an abstract dataset

      Args:
-     ----
          url: URL of the dataset
          file_name: name of the file once downloaded
          file_hash: expected SHA256 of the file
@@ -96,13 +96,13 @@ class _VisionDataset(_AbstractDataset):
      def __init__(
          self,
          url: str,
-         file_name: Optional[str] = None,
-         file_hash: Optional[str] = None,
+         file_name: str | None = None,
+         file_hash: str | None = None,
          extract_archive: bool = False,
          download: bool = False,
          overwrite: bool = False,
-         cache_dir: Optional[str] = None,
-         cache_subdir: Optional[str] = None,
+         cache_dir: str | None = None,
+         cache_subdir: str | None = None,
          **kwargs: Any,
      ) -> None:
          cache_dir = (
@@ -115,7 +115,7 @@ class _VisionDataset(_AbstractDataset):

          file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
          # Download the file if not present
-         archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name)
+         archive_path: str | Path = os.path.join(cache_dir, cache_subdir, file_name)

          if not os.path.exists(archive_path) and not download:
              raise ValueError("the dataset needs to be downloaded first with download=True")
doctr/datasets/datasets/pytorch.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
  from copy import deepcopy
- from typing import Any, List, Tuple
+ from typing import Any

  import numpy as np
  import torch
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
  class AbstractDataset(_AbstractDataset):
      """Abstract class for all datasets"""

-     def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
+     def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]:
          img_name, target = self.data[index]

          # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
              assert "labels" in target, "Target should contain 'labels' key"
          elif isinstance(target, tuple):
              assert len(target) == 2
-             assert isinstance(target[0], str) or isinstance(
-                 target[0], np.ndarray
-             ), "first element of the tuple should be a string or a numpy array"
+             assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+                 "first element of the tuple should be a string or a numpy array"
+             )
              assert isinstance(target[1], list), "second element of the tuple should be a list"
          else:
-             assert isinstance(target, str) or isinstance(
-                 target, np.ndarray
-             ), "Target should be a string or a numpy array"
+             assert isinstance(target, str) or isinstance(target, np.ndarray), (
+                 "Target should be a string or a numpy array"
+             )

          # Read image
          img = (
@@ -48,11 +48,11 @@ class AbstractDataset(_AbstractDataset):
          return img, deepcopy(target)

      @staticmethod
-     def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
+     def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]:
          images, targets = zip(*samples)
-         images = torch.stack(images, dim=0)  # type: ignore[assignment]
+         images = torch.stack(images, dim=0)

-         return images, list(targets)  # type: ignore[return-value]
+         return images, list(targets)


  class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
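`collate_fn` is behaviourally unchanged here (only the `# type: ignore` comments are gone); it stacks the image tensors and keeps the targets as a plain list, which is what you hand to a `DataLoader`. A sketch, assuming a resize transform so every sample shares a shape and `torch.stack` can batch them:

    from torch.utils.data import DataLoader

    from doctr.datasets import CORD
    from doctr.transforms import Resize

    # Resize so that all image tensors share a shape before stacking
    train_set = CORD(train=True, download=True, img_transforms=Resize((1024, 1024)))
    loader = DataLoader(train_set, batch_size=8, collate_fn=train_set.collate_fn)
    images, targets = next(iter(loader))
    # images: (8, 3, 1024, 1024) tensor; targets: list of 8 target dicts

The TensorFlow twin below is analogous, batching with `tf.stack(images, axis=0)`.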
doctr/datasets/datasets/tensorflow.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
  from copy import deepcopy
- from typing import Any, List, Tuple
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
  class AbstractDataset(_AbstractDataset):
      """Abstract class for all datasets"""

-     def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
+     def _read_sample(self, index: int) -> tuple[tf.Tensor, Any]:
          img_name, target = self.data[index]

          # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
              assert "labels" in target, "Target should contain 'labels' key"
          elif isinstance(target, tuple):
              assert len(target) == 2
-             assert isinstance(target[0], str) or isinstance(
-                 target[0], np.ndarray
-             ), "first element of the tuple should be a string or a numpy array"
+             assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+                 "first element of the tuple should be a string or a numpy array"
+             )
              assert isinstance(target[1], list), "second element of the tuple should be a list"
          else:
-             assert isinstance(target, str) or isinstance(
-                 target, np.ndarray
-             ), "Target should be a string or a numpy array"
+             assert isinstance(target, str) or isinstance(target, np.ndarray), (
+                 "Target should be a string or a numpy array"
+             )

          # Read image
          img = (
@@ -48,7 +48,7 @@ class AbstractDataset(_AbstractDataset):
          return img, deepcopy(target)

      @staticmethod
-     def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]:
+     def collate_fn(samples: list[tuple[tf.Tensor, Any]]) -> tuple[tf.Tensor, list[Any]]:
          images, targets = zip(*samples)
          images = tf.stack(images, axis=0)

doctr/datasets/detection.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import json
  import os
- from typing import Any, Dict, List, Tuple, Type, Union
+ from typing import Any

  import numpy as np

@@ -26,7 +26,6 @@ class DetectionDataset(AbstractDataset):
      >>> img, target = train_set[0]

      Args:
-     ----
          img_folder: folder with all the images of the dataset
          label_path: path to the annotations of each image
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
@@ -47,13 +46,13 @@ class DetectionDataset(AbstractDataset):
          )

          # File existence check
-         self._class_names: List = []
+         self._class_names: list = []
          if not os.path.exists(label_path):
              raise FileNotFoundError(f"unable to locate {label_path}")
          with open(label_path, "rb") as f:
              labels = json.load(f)

-         self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
+         self.data: list[tuple[str, tuple[np.ndarray, list[str]]]] = []
          np_dtype = np.float32
          for img_name, label in labels.items():
              # File existence check
@@ -65,18 +64,16 @@ class DetectionDataset(AbstractDataset):
              self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))

      def format_polygons(
-         self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
-     ) -> Tuple[np.ndarray, List[str]]:
+         self, polygons: list | dict, use_polygons: bool, np_dtype: type
+     ) -> tuple[np.ndarray, list[str]]:
          """Format polygons into an array

          Args:
-         ----
              polygons: the bounding boxes
              use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
              np_dtype: dtype of array

          Returns:
-         -------
              geoms: bounding boxes as np array
              polygons_classes: list of classes for each bounding box
          """
doctr/datasets/doc_artefacts.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import json
  import os
- from typing import Any, Dict, List, Tuple
+ from typing import Any

  import numpy as np

@@ -26,7 +26,6 @@ class DocArtefacts(VisionDataset):
      >>> img, target = train_set[0]

      Args:
-     ----
          train: whether the subset should be the training one
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          **kwargs: keyword arguments from `VisionDataset`.
@@ -51,7 +50,7 @@ class DocArtefacts(VisionDataset):
          tmp_root = os.path.join(self.root, "images")
          with open(os.path.join(self.root, "labels.json"), "rb") as f:
              labels = json.load(f)
-         self.data: List[Tuple[str, Dict[str, Any]]] = []
+         self.data: list[tuple[str, dict[str, Any]]] = []
          img_list = os.listdir(tmp_root)
          if len(labels) != len(img_list):
              raise AssertionError("the number of images and labels do not match")
doctr/datasets/funsd.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
  import json
  import os
  from pathlib import Path
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import numpy as np
  from tqdm import tqdm
@@ -29,10 +29,10 @@ class FUNSD(VisionDataset):
      >>> img, target = train_set[0]

      Args:
-     ----
          train: whether the subset should be the training one
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          recognition_task: whether the dataset should be used for recognition task
+         detection_task: whether the dataset should be used for detection task
          **kwargs: keyword arguments from `VisionDataset`.
      """

@@ -45,6 +45,7 @@ class FUNSD(VisionDataset):
          train: bool = True,
          use_polygons: bool = False,
          recognition_task: bool = False,
+         detection_task: bool = False,
          **kwargs: Any,
      ) -> None:
          super().__init__(
@@ -55,16 +56,24 @@ class FUNSD(VisionDataset):
              pre_transforms=convert_target_to_relative if not recognition_task else None,
              **kwargs,
          )
+         if recognition_task and detection_task:
+             raise ValueError(
+                 "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                 + "To get the whole dataset with boxes and labels leave both parameters to False."
+             )
+
          self.train = train
          np_dtype = np.float32

          # Use the subset
          subfolder = os.path.join("dataset", "training_data" if train else "testing_data")

-         # # List images
+         # # list images
          tmp_root = os.path.join(self.root, subfolder, "images")
-         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
-         for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))):
+         self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
+         for img_path in tqdm(
+             iterable=os.listdir(tmp_root), desc="Preparing and Loading FUNSD", total=len(os.listdir(tmp_root))
+         ):
              # File existence check
              if not os.path.exists(os.path.join(tmp_root, img_path)):
                  raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
@@ -100,6 +109,8 @@ class FUNSD(VisionDataset):
              # filter labels with unknown characters
              if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]):
                  self.data.append((crop, label))
+         elif detection_task:
+             self.data.append((img_path, np.asarray(box_targets, dtype=np_dtype)))
          else:
              self.data.append((
                  img_path,
doctr/datasets/generator/__init__.py CHANGED
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
-     from .tensorflow import *
- elif is_torch_available():
-     from .pytorch import *  # type: ignore[assignment]
+ if is_torch_available():
+     from .pytorch import *
+ elif is_tf_available():
+     from .tensorflow import *  # type: ignore[assignment]