python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +8 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +7 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +4 -5
  17. doctr/datasets/ic13.py +4 -5
  18. doctr/datasets/iiit5k.py +6 -5
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +6 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +6 -5
  27. doctr/datasets/svhn.py +6 -5
  28. doctr/datasets/svt.py +4 -5
  29. doctr/datasets/synthtext.py +4 -5
  30. doctr/datasets/utils.py +34 -29
  31. doctr/datasets/vocabs.py +17 -7
  32. doctr/datasets/wildreceipt.py +14 -10
  33. doctr/file_utils.py +2 -7
  34. doctr/io/elements.py +59 -79
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +30 -48
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +8 -11
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +5 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +8 -21
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +6 -8
  52. doctr/models/classification/predictor/tensorflow.py +6 -8
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +20 -31
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +8 -15
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +9 -12
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +6 -12
  65. doctr/models/classification/zoo.py +19 -14
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +14 -26
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +14 -23
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +5 -6
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +3 -7
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +4 -5
  91. doctr/models/kie_predictor/pytorch.py +18 -19
  92. doctr/models/kie_predictor/tensorflow.py +13 -14
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -10
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  101. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +28 -29
  104. doctr/models/predictor/pytorch.py +12 -13
  105. doctr/models/predictor/tensorflow.py +8 -9
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +10 -14
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +11 -23
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +12 -22
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +16 -22
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +12 -21
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +12 -20
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +14 -17
  136. doctr/models/utils/tensorflow.py +17 -16
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +20 -28
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +58 -22
  145. doctr/transforms/modules/tensorflow.py +18 -32
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +16 -47
  150. doctr/utils/metrics.py +17 -37
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +9 -13
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.10.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
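
Note: most hunks below follow two mechanical patterns: a copyright-year bump from 2024 to 2025, and a migration from the `typing` aliases (`List`, `Dict`, `Tuple`, `Optional`, `Union`, `Callable`) to built-in generics (PEP 585) and `X | None` unions (PEP 604). A minimal before/after sketch of the typing change, using a hypothetical function that is not part of the package:

    # before (0.10.0 style)
    from typing import Dict, List, Optional

    def label_boxes(boxes: List[float], names: Optional[Dict[str, int]] = None) -> Optional[List[str]]:
        ...

    # after (0.11.0 style); the `X | None` syntax in annotations requires
    # Python 3.10+ at runtime unless `from __future__ import annotations` is used
    def label_boxes(boxes: list[float], names: dict[str, int] | None = None) -> list[str] | None:
        ...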
doctr/contrib/__init__.py CHANGED
@@ -0,0 +1 @@
+ from .artefacts import ArtefactDetector
doctr/contrib/artefacts.py CHANGED
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any

  import cv2
  import numpy as np
@@ -14,7 +14,7 @@ from .base import _BasePredictor

  __all__ = ["ArtefactDetector"]

- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
  "yolov8_artefact": {
  "input_shape": (3, 1024, 1024),
  "labels": ["bar_code", "qr_code", "logo", "photo"],
@@ -34,7 +34,6 @@ class ArtefactDetector(_BasePredictor):
  >>> results = detector(doc)

  Args:
- ----
  arch: the architecture to use
  batch_size: the batch size to use
  model_path: the path to the model to use
@@ -50,9 +49,9 @@ class ArtefactDetector(_BasePredictor):
  self,
  arch: str = "yolov8_artefact",
  batch_size: int = 2,
- model_path: Optional[str] = None,
- labels: Optional[List[str]] = None,
- input_shape: Optional[Tuple[int, int, int]] = None,
+ model_path: str | None = None,
+ labels: list[str] | None = None,
+ input_shape: tuple[int, int, int] | None = None,
  conf_threshold: float = 0.5,
  iou_threshold: float = 0.5,
  **kwargs: Any,
@@ -66,7 +65,7 @@ class ArtefactDetector(_BasePredictor):
  def preprocess(self, img: np.ndarray) -> np.ndarray:
  return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)

- def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
+ def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> list[list[dict[str, Any]]]:
  results = []

  for batch in zip(output, input_images):
@@ -109,7 +108,6 @@ class ArtefactDetector(_BasePredictor):
  Display the results

  Args:
- ----
  **kwargs: additional keyword arguments to be passed to `plt.show`
  """
  requires_package("matplotlib", "`.show()` requires matplotlib installed")
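
Note: the unchanged `preprocess` line above resizes an HWC page image to the configured input size, moves channels first, and scales pixels to [0, 1]. A standalone sketch of the same shape flow, with a dummy image that is not from the package:

    import cv2
    import numpy as np

    img = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8)  # HWC page image
    resized = cv2.resize(img, (1024, 1024))                          # (1024, 1024, 3)
    tensor = np.transpose(resized, (2, 0, 1)) / 255.0                # (3, 1024, 1024), float in [0, 1]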
doctr/contrib/base.py CHANGED
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, List, Optional
+ from typing import Any

  import numpy as np

@@ -16,32 +16,29 @@ class _BasePredictor:
  Base class for all predictors

  Args:
- ----
  batch_size: the batch size to use
  url: the url to use to download a model if needed
  model_path: the path to the model to use
  **kwargs: additional arguments to be passed to `download_from_url`
  """

- def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None:
+ def __init__(self, batch_size: int, url: str | None = None, model_path: str | None = None, **kwargs) -> None:
  self.batch_size = batch_size
  self.session = self._init_model(url, model_path, **kwargs)

- self._inputs: List[np.ndarray] = []
- self._results: List[Any] = []
+ self._inputs: list[np.ndarray] = []
+ self._results: list[Any] = []

- def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
+ def _init_model(self, url: str | None = None, model_path: str | None = None, **kwargs: Any) -> Any:
  """
  Download the model from the given url if needed

  Args:
- ----
  url: the url to use
  model_path: the path to the model to use
  **kwargs: additional arguments to be passed to `download_from_url`

  Returns:
- -------
  Any: the ONNX loaded model
  """
  requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
@@ -57,40 +54,34 @@ class _BasePredictor:
  Preprocess the input image

  Args:
- ----
  img: the input image to preprocess

  Returns:
- -------
  np.ndarray: the preprocessed image
  """
  raise NotImplementedError

- def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
+ def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> Any:
  """
  Postprocess the model output

  Args:
- ----
  output: the model output to postprocess
  input_images: the input images used to generate the output

  Returns:
- -------
  Any: the postprocessed output
  """
  raise NotImplementedError

- def __call__(self, inputs: List[np.ndarray]) -> Any:
+ def __call__(self, inputs: list[np.ndarray]) -> Any:
  """
  Call the model on the given inputs

  Args:
- ----
  inputs: the inputs to use

  Returns:
- -------
  Any: the postprocessed output
  """
  self._inputs = inputs
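
Note: the contract above (subclasses implement `preprocess` and `postprocess`; `__call__` handles batching against the ONNX session) is what `ArtefactDetector` builds on. A minimal sketch of a custom subclass, assuming a valid ONNX model `url` or `model_path` is supplied at construction time; the class name and transforms here are hypothetical:

    import numpy as np
    from doctr.contrib.base import _BasePredictor

    class MyDetector(_BasePredictor):
        def preprocess(self, img: np.ndarray) -> np.ndarray:
            # scale pixel values to [0, 1] before inference
            return img.astype(np.float32) / 255.0

        def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> list[np.ndarray]:
            # pass raw model outputs through unchanged
            return output

    # detector = MyDetector(batch_size=2, model_path="path/to/model.onnx")
    # results = detector(list_of_numpy_images)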
doctr/datasets/cord.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
  import json
  import os
  from pathlib import Path
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import numpy as np
  from tqdm import tqdm
@@ -29,7 +29,6 @@ class CORD(VisionDataset):
  >>> img, target = train_set[0]

  Args:
- ----
  train: whether the subset should be the training one
  use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
  recognition_task: whether the dataset should be used for recognition task
@@ -72,12 +71,14 @@ class CORD(VisionDataset):
  + "To get the whole dataset with boxes and labels leave both parameters to False."
  )

- # List images
+ # list images
  tmp_root = os.path.join(self.root, "image")
- self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
+ self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
  self.train = train
  np_dtype = np.float32
- for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))):
+ for img_path in tqdm(
+ iterable=os.listdir(tmp_root), desc="Preparing and Loading CORD", total=len(os.listdir(tmp_root))
+ ):
  # File existence check
  if not os.path.exists(os.path.join(tmp_root, img_path)):
  raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
@@ -91,7 +92,7 @@
  if len(word["text"]) > 0:
  x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"]
  y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"]
- box: Union[List[float], np.ndarray]
+ box: list[float] | np.ndarray
  if use_polygons:
  # (x, y) coordinates of top left, top right, bottom right, bottom left corners
  box = np.array(
doctr/datasets/datasets/__init__.py CHANGED
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
- from .tensorflow import *
- elif is_torch_available():
- from .pytorch import *  # type: ignore[assignment]
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
+ from .tensorflow import *  # type: ignore[assignment]
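
Note: this hunk (repeated in `doctr/datasets/generator/__init__.py` below) flips the backend priority: 0.10.0 imported the TensorFlow implementation when both frameworks were present, while 0.11.0 prefers PyTorch. A quick way to check which implementation an environment will resolve to, using the same helpers the package itself uses:

    from doctr.file_utils import is_tf_available, is_torch_available

    # mirrors the 0.11.0 import order: torch wins when both backends are installed
    backend = "pytorch" if is_torch_available() else "tensorflow" if is_tf_available() else None
    print(backend)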
doctr/datasets/datasets/base.py CHANGED
@@ -1,12 +1,13 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
  import shutil
+ from collections.abc import Callable
  from pathlib import Path
- from typing import Any, Callable, List, Optional, Tuple, Union
+ from typing import Any

  import numpy as np

@@ -19,15 +20,15 @@ __all__ = ["_AbstractDataset", "_VisionDataset"]


  class _AbstractDataset:
- data: List[Any] = []
- _pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None
+ data: list[Any] = []
+ _pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None

  def __init__(
  self,
- root: Union[str, Path],
- img_transforms: Optional[Callable[[Any], Any]] = None,
- sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
- pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+ root: str | Path,
+ img_transforms: Callable[[Any], Any] | None = None,
+ sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
+ pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
  ) -> None:
  if not Path(root).is_dir():
  raise ValueError(f"expected a path to a reachable folder: {root}")
@@ -41,10 +42,10 @@ class _AbstractDataset:
  def __len__(self) -> int:
  return len(self.data)

- def _read_sample(self, index: int) -> Tuple[Any, Any]:
+ def _read_sample(self, index: int) -> tuple[Any, Any]:
  raise NotImplementedError

- def __getitem__(self, index: int) -> Tuple[Any, Any]:
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
  # Read image
  img, target = self._read_sample(index)
  # Pre-transforms (format conversion at run-time etc.)
@@ -82,7 +83,6 @@ class _VisionDataset(_AbstractDataset):
  """Implements an abstract dataset

  Args:
- ----
  url: URL of the dataset
  file_name: name of the file once downloaded
  file_hash: expected SHA256 of the file
@@ -96,13 +96,13 @@ class _VisionDataset(_AbstractDataset):
  def __init__(
  self,
  url: str,
- file_name: Optional[str] = None,
- file_hash: Optional[str] = None,
+ file_name: str | None = None,
+ file_hash: str | None = None,
  extract_archive: bool = False,
  download: bool = False,
  overwrite: bool = False,
- cache_dir: Optional[str] = None,
- cache_subdir: Optional[str] = None,
+ cache_dir: str | None = None,
+ cache_subdir: str | None = None,
  **kwargs: Any,
  ) -> None:
  cache_dir = (
@@ -115,7 +115,7 @@ class _VisionDataset(_AbstractDataset):

  file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
  # Download the file if not present
- archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name)
+ archive_path: str | Path = os.path.join(cache_dir, cache_subdir, file_name)

  if not os.path.exists(archive_path) and not download:
  raise ValueError("the dataset needs to be downloaded first with download=True")
doctr/datasets/datasets/pytorch.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
  from copy import deepcopy
- from typing import Any, List, Tuple
+ from typing import Any

  import numpy as np
  import torch
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
  class AbstractDataset(_AbstractDataset):
  """Abstract class for all datasets"""

- def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
+ def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]:
  img_name, target = self.data[index]

  # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
  assert "labels" in target, "Target should contain 'labels' key"
  elif isinstance(target, tuple):
  assert len(target) == 2
- assert isinstance(target[0], str) or isinstance(
- target[0], np.ndarray
- ), "first element of the tuple should be a string or a numpy array"
+ assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+ "first element of the tuple should be a string or a numpy array"
+ )
  assert isinstance(target[1], list), "second element of the tuple should be a list"
  else:
- assert isinstance(target, str) or isinstance(
- target, np.ndarray
- ), "Target should be a string or a numpy array"
+ assert isinstance(target, str) or isinstance(target, np.ndarray), (
+ "Target should be a string or a numpy array"
+ )

  # Read image
  img = (
@@ -48,11 +48,11 @@ class AbstractDataset(_AbstractDataset):
  return img, deepcopy(target)

  @staticmethod
- def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
+ def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]:
  images, targets = zip(*samples)
- images = torch.stack(images, dim=0)  # type: ignore[assignment]
+ images = torch.stack(images, dim=0)

- return images, list(targets)  # type: ignore[return-value]
+ return images, list(targets)


  class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
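
Note: with the `# type: ignore` comments gone, `collate_fn` reads as plain batching logic: stack the image tensors into one batch tensor and keep the targets as a Python list. A small sketch of what it produces, with dummy tensors that are not package data:

    import torch
    from doctr.datasets.datasets import AbstractDataset

    samples = [(torch.rand(3, 32, 32), {"labels": ["a"]}) for _ in range(4)]
    images, targets = AbstractDataset.collate_fn(samples)
    # images: torch.Size([4, 3, 32, 32]); targets: list of 4 dicts
    # typically passed to DataLoader(..., collate_fn=AbstractDataset.collate_fn)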
doctr/datasets/datasets/tensorflow.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
  from copy import deepcopy
- from typing import Any, List, Tuple
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -20,7 +20,7 @@ __all__ = ["AbstractDataset", "VisionDataset"]
  class AbstractDataset(_AbstractDataset):
  """Abstract class for all datasets"""

- def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
+ def _read_sample(self, index: int) -> tuple[tf.Tensor, Any]:
  img_name, target = self.data[index]

  # Check target
@@ -29,14 +29,14 @@ class AbstractDataset(_AbstractDataset):
  assert "labels" in target, "Target should contain 'labels' key"
  elif isinstance(target, tuple):
  assert len(target) == 2
- assert isinstance(target[0], str) or isinstance(
- target[0], np.ndarray
- ), "first element of the tuple should be a string or a numpy array"
+ assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
+ "first element of the tuple should be a string or a numpy array"
+ )
  assert isinstance(target[1], list), "second element of the tuple should be a list"
  else:
- assert isinstance(target, str) or isinstance(
- target, np.ndarray
- ), "Target should be a string or a numpy array"
+ assert isinstance(target, str) or isinstance(target, np.ndarray), (
+ "Target should be a string or a numpy array"
+ )

  # Read image
  img = (
@@ -48,7 +48,7 @@ class AbstractDataset(_AbstractDataset):
  return img, deepcopy(target)

  @staticmethod
- def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]:
+ def collate_fn(samples: list[tuple[tf.Tensor, Any]]) -> tuple[tf.Tensor, list[Any]]:
  images, targets = zip(*samples)
  images = tf.stack(images, axis=0)

doctr/datasets/detection.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import json
  import os
- from typing import Any, Dict, List, Tuple, Type, Union
+ from typing import Any

  import numpy as np

@@ -26,7 +26,6 @@ class DetectionDataset(AbstractDataset):
  >>> img, target = train_set[0]

  Args:
- ----
  img_folder: folder with all the images of the dataset
  label_path: path to the annotations of each image
  use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
@@ -47,13 +46,13 @@ class DetectionDataset(AbstractDataset):
  )

  # File existence check
- self._class_names: List = []
+ self._class_names: list = []
  if not os.path.exists(label_path):
  raise FileNotFoundError(f"unable to locate {label_path}")
  with open(label_path, "rb") as f:
  labels = json.load(f)

- self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
+ self.data: list[tuple[str, tuple[np.ndarray, list[str]]]] = []
  np_dtype = np.float32
  for img_name, label in labels.items():
  # File existence check
@@ -65,18 +64,16 @@ class DetectionDataset(AbstractDataset):
  self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))

  def format_polygons(
- self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
- ) -> Tuple[np.ndarray, List[str]]:
+ self, polygons: list | dict, use_polygons: bool, np_dtype: type
+ ) -> tuple[np.ndarray, list[str]]:
  """Format polygons into an array

  Args:
- ----
  polygons: the bounding boxes
  use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
  np_dtype: dtype of array

  Returns:
- -------
  geoms: bounding boxes as np array
  polygons_classes: list of classes for each bounding box
  """
doctr/datasets/doc_artefacts.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import json
  import os
- from typing import Any, Dict, List, Tuple
+ from typing import Any

  import numpy as np

@@ -26,7 +26,6 @@ class DocArtefacts(VisionDataset):
  >>> img, target = train_set[0]

  Args:
- ----
  train: whether the subset should be the training one
  use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
  **kwargs: keyword arguments from `VisionDataset`.
@@ -51,7 +50,7 @@ class DocArtefacts(VisionDataset):
  tmp_root = os.path.join(self.root, "images")
  with open(os.path.join(self.root, "labels.json"), "rb") as f:
  labels = json.load(f)
- self.data: List[Tuple[str, Dict[str, Any]]] = []
+ self.data: list[tuple[str, dict[str, Any]]] = []
  img_list = os.listdir(tmp_root)
  if len(labels) != len(img_list):
  raise AssertionError("the number of images and labels do not match")
doctr/datasets/funsd.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
  import json
  import os
  from pathlib import Path
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import numpy as np
  from tqdm import tqdm
@@ -29,7 +29,6 @@ class FUNSD(VisionDataset):
  >>> img, target = train_set[0]

  Args:
- ----
  train: whether the subset should be the training one
  use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
  recognition_task: whether the dataset should be used for recognition task
@@ -69,10 +68,12 @@ class FUNSD(VisionDataset):
  # Use the subset
  subfolder = os.path.join("dataset", "training_data" if train else "testing_data")

- # # List images
+ # # list images
  tmp_root = os.path.join(self.root, subfolder, "images")
- self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
- for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))):
+ self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
+ for img_path in tqdm(
+ iterable=os.listdir(tmp_root), desc="Preparing and Loading FUNSD", total=len(os.listdir(tmp_root))
+ ):
  # File existence check
  if not os.path.exists(os.path.join(tmp_root, img_path)):
  raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
doctr/datasets/generator/__init__.py CHANGED
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
- from .tensorflow import *
- elif is_torch_available():
- from .pytorch import *  # type: ignore[assignment]
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
+ from .tensorflow import *  # type: ignore[assignment]
doctr/datasets/generator/base.py CHANGED
@@ -1,10 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import random
- from typing import Any, Callable, List, Optional, Tuple, Union
+ from collections.abc import Callable
+ from typing import Any

  from PIL import Image, ImageDraw

@@ -17,14 +18,13 @@ from ..datasets import AbstractDataset
  def synthesize_text_img(
  text: str,
  font_size: int = 32,
- font_family: Optional[str] = None,
- background_color: Optional[Tuple[int, int, int]] = None,
- text_color: Optional[Tuple[int, int, int]] = None,
+ font_family: str | None = None,
+ background_color: tuple[int, int, int] | None = None,
+ text_color: tuple[int, int, int] | None = None,
  ) -> Image.Image:
  """Generate a synthetic text image

  Args:
- ----
  text: the text to render as an image
  font_size: the size of the font
  font_family: the font family (has to be installed on your system)
@@ -32,7 +32,6 @@ def synthesize_text_img(
  text_color: text color on the final image

  Returns:
- -------
  PIL image of the text
  """
  background_color = (0, 0, 0) if background_color is None else background_color
@@ -61,9 +60,9 @@ class _CharacterGenerator(AbstractDataset):
  vocab: str,
  num_samples: int,
  cache_samples: bool = False,
- font_family: Optional[Union[str, List[str]]] = None,
- img_transforms: Optional[Callable[[Any], Any]] = None,
- sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+ font_family: str | list[str] | None = None,
+ img_transforms: Callable[[Any], Any] | None = None,
+ sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
  ) -> None:
  self.vocab = vocab
  self._num_samples = num_samples
@@ -78,7 +77,7 @@ class _CharacterGenerator(AbstractDataset):
  self.img_transforms = img_transforms
  self.sample_transforms = sample_transforms

- self._data: List[Image.Image] = []
+ self._data: list[Image.Image] = []
  if cache_samples:
  self._data = [
  (synthesize_text_img(char, font_family=font), idx)  # type: ignore[misc]
@@ -89,7 +88,7 @@ class _CharacterGenerator(AbstractDataset):
  def __len__(self) -> int:
  return self._num_samples

- def _read_sample(self, index: int) -> Tuple[Any, int]:
+ def _read_sample(self, index: int) -> tuple[Any, int]:
  # Samples are already cached
  if len(self._data) > 0:
  idx = index % len(self._data)
@@ -110,9 +109,9 @@ class _WordGenerator(AbstractDataset):
  max_chars: int,
  num_samples: int,
  cache_samples: bool = False,
- font_family: Optional[Union[str, List[str]]] = None,
- img_transforms: Optional[Callable[[Any], Any]] = None,
- sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+ font_family: str | list[str] | None = None,
+ img_transforms: Callable[[Any], Any] | None = None,
+ sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None,
  ) -> None:
  self.vocab = vocab
  self.wordlen_range = (min_chars, max_chars)
@@ -128,7 +127,7 @@ class _WordGenerator(AbstractDataset):
  self.img_transforms = img_transforms
  self.sample_transforms = sample_transforms

- self._data: List[Image.Image] = []
+ self._data: list[Image.Image] = []
  if cache_samples:
  _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
  self._data = [
@@ -143,7 +142,7 @@ class _WordGenerator(AbstractDataset):
  def __len__(self) -> int:
  return self._num_samples

- def _read_sample(self, index: int) -> Tuple[Any, str]:
+ def _read_sample(self, index: int) -> tuple[Any, str]:
  # Samples are already cached
  if len(self._data) > 0:
  pil_img, target = self._data[index]  # type: ignore[misc]
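
Note: the `synthesize_text_img` signature shown above is the entry point both generators cache their samples through. A minimal usage sketch based on the defaults visible in this hunk (black background unless overridden; a `font_family`, if given, must be installed on the system):

    from doctr.datasets.generator.base import synthesize_text_img

    img = synthesize_text_img("doctr", font_size=32)  # -> PIL.Image.Image
    print(img.size)  # (width, height) of the rendered word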