python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
Files changed (107)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/cord.py +10 -1
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +11 -1
  8. doctr/datasets/generator/base.py +6 -5
  9. doctr/datasets/ic03.py +11 -1
  10. doctr/datasets/ic13.py +10 -1
  11. doctr/datasets/iiit5k.py +26 -16
  12. doctr/datasets/imgur5k.py +11 -2
  13. doctr/datasets/loader.py +1 -6
  14. doctr/datasets/sroie.py +11 -1
  15. doctr/datasets/svhn.py +11 -1
  16. doctr/datasets/svt.py +11 -1
  17. doctr/datasets/synthtext.py +11 -1
  18. doctr/datasets/utils.py +9 -3
  19. doctr/datasets/vocabs.py +15 -4
  20. doctr/datasets/wildreceipt.py +12 -1
  21. doctr/file_utils.py +45 -12
  22. doctr/io/elements.py +52 -10
  23. doctr/io/html.py +2 -2
  24. doctr/io/image/pytorch.py +6 -8
  25. doctr/io/image/tensorflow.py +1 -1
  26. doctr/io/pdf.py +5 -2
  27. doctr/io/reader.py +6 -0
  28. doctr/models/__init__.py +0 -1
  29. doctr/models/_utils.py +57 -20
  30. doctr/models/builder.py +73 -15
  31. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  32. doctr/models/classification/mobilenet/pytorch.py +47 -9
  33. doctr/models/classification/mobilenet/tensorflow.py +51 -14
  34. doctr/models/classification/predictor/pytorch.py +28 -17
  35. doctr/models/classification/predictor/tensorflow.py +26 -16
  36. doctr/models/classification/resnet/tensorflow.py +21 -8
  37. doctr/models/classification/textnet/pytorch.py +3 -3
  38. doctr/models/classification/textnet/tensorflow.py +11 -5
  39. doctr/models/classification/vgg/tensorflow.py +9 -3
  40. doctr/models/classification/vit/tensorflow.py +10 -4
  41. doctr/models/classification/zoo.py +55 -19
  42. doctr/models/detection/_utils/__init__.py +1 -0
  43. doctr/models/detection/_utils/base.py +66 -0
  44. doctr/models/detection/differentiable_binarization/base.py +4 -3
  45. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  46. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  47. doctr/models/detection/fast/base.py +6 -5
  48. doctr/models/detection/fast/pytorch.py +4 -4
  49. doctr/models/detection/fast/tensorflow.py +15 -12
  50. doctr/models/detection/linknet/base.py +4 -3
  51. doctr/models/detection/linknet/tensorflow.py +23 -11
  52. doctr/models/detection/predictor/pytorch.py +15 -1
  53. doctr/models/detection/predictor/tensorflow.py +17 -3
  54. doctr/models/detection/zoo.py +7 -2
  55. doctr/models/factory/hub.py +8 -18
  56. doctr/models/kie_predictor/base.py +13 -3
  57. doctr/models/kie_predictor/pytorch.py +45 -20
  58. doctr/models/kie_predictor/tensorflow.py +44 -17
  59. doctr/models/modules/layers/pytorch.py +2 -3
  60. doctr/models/modules/layers/tensorflow.py +6 -8
  61. doctr/models/modules/transformer/pytorch.py +2 -2
  62. doctr/models/modules/transformer/tensorflow.py +0 -2
  63. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  64. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  65. doctr/models/predictor/base.py +97 -58
  66. doctr/models/predictor/pytorch.py +35 -20
  67. doctr/models/predictor/tensorflow.py +35 -18
  68. doctr/models/preprocessor/pytorch.py +4 -4
  69. doctr/models/preprocessor/tensorflow.py +3 -2
  70. doctr/models/recognition/crnn/tensorflow.py +8 -6
  71. doctr/models/recognition/master/pytorch.py +2 -2
  72. doctr/models/recognition/master/tensorflow.py +9 -4
  73. doctr/models/recognition/parseq/pytorch.py +4 -3
  74. doctr/models/recognition/parseq/tensorflow.py +14 -11
  75. doctr/models/recognition/sar/pytorch.py +7 -6
  76. doctr/models/recognition/sar/tensorflow.py +10 -12
  77. doctr/models/recognition/vitstr/pytorch.py +1 -1
  78. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  79. doctr/models/recognition/zoo.py +1 -1
  80. doctr/models/utils/pytorch.py +1 -1
  81. doctr/models/utils/tensorflow.py +15 -15
  82. doctr/models/zoo.py +2 -2
  83. doctr/py.typed +0 -0
  84. doctr/transforms/functional/base.py +1 -1
  85. doctr/transforms/functional/pytorch.py +5 -5
  86. doctr/transforms/modules/base.py +37 -15
  87. doctr/transforms/modules/pytorch.py +73 -14
  88. doctr/transforms/modules/tensorflow.py +78 -19
  89. doctr/utils/fonts.py +7 -5
  90. doctr/utils/geometry.py +141 -31
  91. doctr/utils/metrics.py +34 -175
  92. doctr/utils/reconstitution.py +212 -0
  93. doctr/utils/visualization.py +5 -118
  94. doctr/version.py +1 -1
  95. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
  96. python_doctr-0.10.0.dist-info/RECORD +173 -0
  97. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  98. doctr/models/artefacts/__init__.py +0 -2
  99. doctr/models/artefacts/barcode.py +0 -74
  100. doctr/models/artefacts/face.py +0 -63
  101. doctr/models/obj_detection/__init__.py +0 -1
  102. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  103. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  104. python_doctr-0.8.1.dist-info/RECORD +0 -173
  105. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  106. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  107. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/__init__.py CHANGED
@@ -1,3 +1,3 @@
- from . import io, models, datasets, transforms, utils
+ from . import io, models, datasets, contrib, transforms, utils
  from .file_utils import is_tf_available, is_torch_available
  from .version import __version__ # noqa: F401
doctr/contrib/__init__.py ADDED
File without changes
doctr/contrib/artefacts.py ADDED
@@ -0,0 +1,131 @@
+ # Copyright (C) 2021-2024, Mindee.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import cv2
+ import numpy as np
+
+ from doctr.file_utils import requires_package
+
+ from .base import _BasePredictor
+
+ __all__ = ["ArtefactDetector"]
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "yolov8_artefact": {
+         "input_shape": (3, 1024, 1024),
+         "labels": ["bar_code", "qr_code", "logo", "photo"],
+         "url": "https://doctr-static.mindee.com/models?id=v0.8.1/yolo_artefact-f9d66f14.onnx&src=0",
+     },
+ }
+
+
+ class ArtefactDetector(_BasePredictor):
+     """
+     A class to detect artefacts in images
+
+     >>> from doctr.io import DocumentFile
+     >>> from doctr.contrib.artefacts import ArtefactDetector
+     >>> doc = DocumentFile.from_images(["path/to/image.jpg"])
+     >>> detector = ArtefactDetector()
+     >>> results = detector(doc)
+
+     Args:
+     ----
+         arch: the architecture to use
+         batch_size: the batch size to use
+         model_path: the path to the model to use
+         labels: the labels to use
+         input_shape: the input shape to use
+         mask_labels: the mask labels to use
+         conf_threshold: the confidence threshold to use
+         iou_threshold: the intersection over union threshold to use
+         **kwargs: additional arguments to be passed to `download_from_url`
+     """
+
+     def __init__(
+         self,
+         arch: str = "yolov8_artefact",
+         batch_size: int = 2,
+         model_path: Optional[str] = None,
+         labels: Optional[List[str]] = None,
+         input_shape: Optional[Tuple[int, int, int]] = None,
+         conf_threshold: float = 0.5,
+         iou_threshold: float = 0.5,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(batch_size=batch_size, url=default_cfgs[arch]["url"], model_path=model_path, **kwargs)
+         self.labels = labels or default_cfgs[arch]["labels"]
+         self.input_shape = input_shape or default_cfgs[arch]["input_shape"]
+         self.conf_threshold = conf_threshold
+         self.iou_threshold = iou_threshold
+
+     def preprocess(self, img: np.ndarray) -> np.ndarray:
+         return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)
+
+     def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
+         results = []
+
+         for batch in zip(output, input_images):
+             for out, img in zip(batch[0], batch[1]):
+                 org_height, org_width = img.shape[:2]
+                 width_scale, height_scale = org_width / self.input_shape[2], org_height / self.input_shape[1]
+                 for res in out:
+                     sample_results = []
+                     for row in np.transpose(np.squeeze(res)):
+                         classes_scores = row[4:]
+                         max_score = np.amax(classes_scores)
+                         if max_score >= self.conf_threshold:
+                             class_id = np.argmax(classes_scores)
+                             x, y, w, h = row[0], row[1], row[2], row[3]
+                             # to rescaled xmin, ymin, xmax, ymax
+                             xmin = int((x - w / 2) * width_scale)
+                             ymin = int((y - h / 2) * height_scale)
+                             xmax = int((x + w / 2) * width_scale)
+                             ymax = int((y + h / 2) * height_scale)
+
+                             sample_results.append({
+                                 "label": self.labels[class_id],
+                                 "confidence": float(max_score),
+                                 "box": [xmin, ymin, xmax, ymax],
+                             })
+
+                     # Filter out overlapping boxes
+                     boxes = [res["box"] for res in sample_results]
+                     scores = [res["confidence"] for res in sample_results]
+                     keep_indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf_threshold, self.iou_threshold) # type: ignore[arg-type]
+                     sample_results = [sample_results[i] for i in keep_indices]
+
+                     results.append(sample_results)
+
+         self._results = results
+         return results
+
+     def show(self, **kwargs: Any) -> None:
+         """
+         Display the results
+
+         Args:
+         ----
+             **kwargs: additional keyword arguments to be passed to `plt.show`
+         """
+         requires_package("matplotlib", "`.show()` requires matplotlib installed")
+         import matplotlib.pyplot as plt
+         from matplotlib.patches import Rectangle
+
+         # visualize the results with matplotlib
+         if self._results and self._inputs:
+             for img, res in zip(self._inputs, self._results):
+                 plt.figure(figsize=(10, 10))
+                 plt.imshow(img)
+                 for obj in res:
+                     xmin, ymin, xmax, ymax = obj["box"]
+                     label = obj["label"]
+                     plt.text(xmin, ymin, f"{label} {obj['confidence']:.2f}", color="red")
+                     plt.gca().add_patch(
+                         Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="red", linewidth=2)
+                     )
+                 plt.show(**kwargs)
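Taken together, the new module runs ONNX-based artefact detection end to end. A minimal usage sketch, assuming `onnxruntime` is installed; the image path is a placeholder, not a file shipped with the package:

from doctr.io import DocumentFile
from doctr.contrib.artefacts import ArtefactDetector

doc = DocumentFile.from_images(["path/to/image.jpg"])  # placeholder path
detector = ArtefactDetector()  # downloads the yolov8_artefact ONNX model on first use
results = detector(doc)  # one list of {"label", "confidence", "box"} dicts per page
for page in results:
    for artefact in page:
        print(artefact["label"], artefact["confidence"], artefact["box"])
detector.show()  # optional matplotlib overlay of the detected boxes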
doctr/contrib/base.py ADDED
@@ -0,0 +1,105 @@
+ # Copyright (C) 2021-2024, Mindee.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, List, Optional
+
+ import numpy as np
+
+ from doctr.file_utils import requires_package
+ from doctr.utils.data import download_from_url
+
+
+ class _BasePredictor:
+     """
+     Base class for all predictors
+
+     Args:
+     ----
+         batch_size: the batch size to use
+         url: the url to use to download a model if needed
+         model_path: the path to the model to use
+         **kwargs: additional arguments to be passed to `download_from_url`
+     """
+
+     def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None:
+         self.batch_size = batch_size
+         self.session = self._init_model(url, model_path, **kwargs)
+
+         self._inputs: List[np.ndarray] = []
+         self._results: List[Any] = []
+
+     def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
+         """
+         Download the model from the given url if needed
+
+         Args:
+         ----
+             url: the url to use
+             model_path: the path to the model to use
+             **kwargs: additional arguments to be passed to `download_from_url`
+
+         Returns:
+         -------
+             Any: the ONNX loaded model
+         """
+         requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
+         import onnxruntime as ort
+
+         if not url and not model_path:
+             raise ValueError("You must provide either a url or a model_path")
+         onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs)) # type: ignore[arg-type]
+         return ort.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+
+     def preprocess(self, img: np.ndarray) -> np.ndarray:
+         """
+         Preprocess the input image
+
+         Args:
+         ----
+             img: the input image to preprocess
+
+         Returns:
+         -------
+             np.ndarray: the preprocessed image
+         """
+         raise NotImplementedError
+
+     def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
+         """
+         Postprocess the model output
+
+         Args:
+         ----
+             output: the model output to postprocess
+             input_images: the input images used to generate the output
+
+         Returns:
+         -------
+             Any: the postprocessed output
+         """
+         raise NotImplementedError
+
+     def __call__(self, inputs: List[np.ndarray]) -> Any:
+         """
+         Call the model on the given inputs
+
+         Args:
+         ----
+             inputs: the inputs to use
+
+         Returns:
+         -------
+             Any: the postprocessed output
+         """
+         self._inputs = inputs
+         model_inputs = self.session.get_inputs()
+
+         batched_inputs = [inputs[i : i + self.batch_size] for i in range(0, len(inputs), self.batch_size)]
+         processed_batches = [
+             np.array([self.preprocess(img) for img in batch], dtype=np.float32) for batch in batched_inputs
+         ]
+
+         outputs = [self.session.run(None, {model_inputs[0].name: batch}) for batch in processed_batches]
+         return self.postprocess(outputs, batched_inputs)
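The contract for third-party predictors is just the two hooks above. A minimal sketch of a custom subclass, assuming `onnxruntime` and `opencv-python` are installed; the class name, resize size, and model path are illustrative only:

import cv2
import numpy as np
from doctr.contrib.base import _BasePredictor

class MyOnnxPredictor(_BasePredictor):  # hypothetical example class
    def preprocess(self, img: np.ndarray) -> np.ndarray:
        # resize to the model input and convert HWC uint8 -> CHW float in [0, 1]
        return np.transpose(cv2.resize(img, (640, 640)), (2, 0, 1)) / 255.0

    def postprocess(self, output, input_images):
        # raw per-batch ONNX outputs in, any structure out
        return output

predictor = MyOnnxPredictor(batch_size=2, model_path="my_model.onnx")  # placeholder path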
doctr/datasets/cord.py CHANGED
@@ -33,6 +33,7 @@ class CORD(VisionDataset):
          train: whether the subset should be the training one
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          recognition_task: whether the dataset should be used for recognition task
+         detection_task: whether the dataset should be used for detection task
          **kwargs: keyword arguments from `VisionDataset`.
      """
 
@@ -53,6 +54,7 @@ class CORD(VisionDataset):
          train: bool = True,
          use_polygons: bool = False,
          recognition_task: bool = False,
+         detection_task: bool = False,
          **kwargs: Any,
      ) -> None:
          url, sha256, name = self.TRAIN if train else self.TEST
@@ -64,10 +66,15 @@ class CORD(VisionDataset):
              pre_transforms=convert_target_to_relative if not recognition_task else None,
              **kwargs,
          )
+         if recognition_task and detection_task:
+             raise ValueError(
+                 "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                 + "To get the whole dataset with boxes and labels leave both parameters to False."
+             )
 
          # List images
          tmp_root = os.path.join(self.root, "image")
-         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
          self.train = train
          np_dtype = np.float32
          for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))):
@@ -109,6 +116,8 @@ class CORD(VisionDataset):
                  )
                  for crop, label in zip(crops, list(text_targets)):
                      self.data.append((crop, label))
+             elif detection_task:
+                 self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0)))
              else:
                  self.data.append((
                      img_path,
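The same `detection_task` switch is added to FUNSD, IC03, IC13, IIIT5K, and IMGUR5K below, with the identical guard against combining it with `recognition_task`. A usage sketch (downloads the CORD archive on first use):

from doctr.datasets import CORD

ds = CORD(train=True, detection_task=True)
img, boxes = ds[0]  # target is now a plain np.ndarray of boxes, not a dict
CORD(train=True, recognition_task=True, detection_task=True)  # raises ValueError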
doctr/datasets/datasets/pytorch.py CHANGED
@@ -50,9 +50,9 @@ class AbstractDataset(_AbstractDataset):
      @staticmethod
      def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
          images, targets = zip(*samples)
-         images = torch.stack(images, dim=0)
+         images = torch.stack(images, dim=0) # type: ignore[assignment]
 
-         return images, list(targets)
+         return images, list(targets) # type: ignore[return-value]
 
 
  class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101
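Only type-checker annotations change here; the collate contract stays the same: image tensors are stacked into one batch tensor and targets remain a Python list. A quick sketch, assuming a PyTorch install:

import torch
from doctr.datasets.datasets import AbstractDataset

samples = [(torch.rand(3, 32, 32), "hello"), (torch.rand(3, 32, 32), "world")]
images, targets = AbstractDataset.collate_fn(samples)
print(images.shape)  # torch.Size([2, 3, 32, 32])
print(targets)  # ['hello', 'world']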
doctr/datasets/funsd.py CHANGED
@@ -33,6 +33,7 @@ class FUNSD(VisionDataset):
          train: whether the subset should be the training one
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          recognition_task: whether the dataset should be used for recognition task
+         detection_task: whether the dataset should be used for detection task
          **kwargs: keyword arguments from `VisionDataset`.
      """
 
@@ -45,6 +46,7 @@ class FUNSD(VisionDataset):
          train: bool = True,
          use_polygons: bool = False,
          recognition_task: bool = False,
+         detection_task: bool = False,
          **kwargs: Any,
      ) -> None:
          super().__init__(
@@ -55,6 +57,12 @@ class FUNSD(VisionDataset):
              pre_transforms=convert_target_to_relative if not recognition_task else None,
              **kwargs,
          )
+         if recognition_task and detection_task:
+             raise ValueError(
+                 "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                 + "To get the whole dataset with boxes and labels leave both parameters to False."
+             )
+
          self.train = train
          np_dtype = np.float32
 
@@ -63,7 +71,7 @@ class FUNSD(VisionDataset):
 
          # # List images
          tmp_root = os.path.join(self.root, subfolder, "images")
-         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
          for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))):
              # File existence check
              if not os.path.exists(os.path.join(tmp_root, img_path)):
@@ -100,6 +108,8 @@ class FUNSD(VisionDataset):
                      # filter labels with unknown characters
                      if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]):
                          self.data.append((crop, label))
+             elif detection_task:
+                 self.data.append((img_path, np.asarray(box_targets, dtype=np_dtype)))
              else:
                  self.data.append((
                      img_path,
doctr/datasets/generator/base.py CHANGED
@@ -20,7 +20,7 @@ def synthesize_text_img(
      font_family: Optional[str] = None,
      background_color: Optional[Tuple[int, int, int]] = None,
      text_color: Optional[Tuple[int, int, int]] = None,
- ) -> Image:
+ ) -> Image.Image:
      """Generate a synthetic text image
 
      Args:
@@ -81,7 +81,7 @@ class _CharacterGenerator(AbstractDataset):
          self._data: List[Image.Image] = []
          if cache_samples:
              self._data = [
-                 (synthesize_text_img(char, font_family=font), idx)
+                 (synthesize_text_img(char, font_family=font), idx) # type: ignore[misc]
                  for idx, char in enumerate(self.vocab)
                  for font in self.font_family
              ]
@@ -93,7 +93,7 @@ class _CharacterGenerator(AbstractDataset):
          # Samples are already cached
          if len(self._data) > 0:
              idx = index % len(self._data)
-             pil_img, target = self._data[idx]
+             pil_img, target = self._data[idx] # type: ignore[misc]
          else:
              target = index % len(self.vocab)
              pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family))
@@ -132,7 +132,8 @@ class _WordGenerator(AbstractDataset):
          if cache_samples:
              _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
              self._data = [
-                 (synthesize_text_img(text, font_family=random.choice(self.font_family)), text) for text in _words
+                 (synthesize_text_img(text, font_family=random.choice(self.font_family)), text) # type: ignore[misc]
+                 for text in _words
              ]
 
      def _generate_string(self, min_chars: int, max_chars: int) -> str:
@@ -145,7 +146,7 @@ class _WordGenerator(AbstractDataset):
      def _read_sample(self, index: int) -> Tuple[Any, str]:
          # Samples are already cached
          if len(self._data) > 0:
-             pil_img, target = self._data[index]
+             pil_img, target = self._data[index] # type: ignore[misc]
          else:
              target = self._generate_string(*self.wordlen_range)
              pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family))
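The `-> Image.Image` fix matters because both generators cache and unpack the return value as a PIL image. A quick sketch of the helper, using only keyword names visible in the hunk; the text value is illustrative:

from doctr.datasets.generator.base import synthesize_text_img

img = synthesize_text_img("doctr", font_family=None)  # a PIL.Image.Image
print(img.size)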
doctr/datasets/ic03.py CHANGED
@@ -32,6 +32,7 @@ class IC03(VisionDataset):
          train: whether the subset should be the training one
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          recognition_task: whether the dataset should be used for recognition task
+         detection_task: whether the dataset should be used for detection task
          **kwargs: keyword arguments from `VisionDataset`.
      """
 
@@ -51,6 +52,7 @@ class IC03(VisionDataset):
          train: bool = True,
          use_polygons: bool = False,
          recognition_task: bool = False,
+         detection_task: bool = False,
          **kwargs: Any,
      ) -> None:
          url, sha256, file_name = self.TRAIN if train else self.TEST
@@ -62,8 +64,14 @@ class IC03(VisionDataset):
              pre_transforms=convert_target_to_relative if not recognition_task else None,
              **kwargs,
          )
+         if recognition_task and detection_task:
+             raise ValueError(
+                 "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                 + "To get the whole dataset with boxes and labels leave both parameters to False."
+             )
+
          self.train = train
-         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
          np_dtype = np.float32
 
          # Load xml data
@@ -117,6 +125,8 @@ class IC03(VisionDataset):
                  for crop, label in zip(crops, labels):
                      if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                          self.data.append((crop, label))
+             elif detection_task:
+                 self.data.append((name.text, boxes))
              else:
                  self.data.append((name.text, dict(boxes=boxes, labels=labels)))
 
doctr/datasets/ic13.py CHANGED
@@ -38,6 +38,7 @@ class IC13(AbstractDataset):
          label_folder: folder with all annotation files for the images
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          recognition_task: whether the dataset should be used for recognition task
+         detection_task: whether the dataset should be used for detection task
          **kwargs: keyword arguments from `AbstractDataset`.
      """
 
@@ -47,11 +48,17 @@ class IC13(AbstractDataset):
          label_folder: str,
          use_polygons: bool = False,
          recognition_task: bool = False,
+         detection_task: bool = False,
          **kwargs: Any,
      ) -> None:
          super().__init__(
              img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
          )
+         if recognition_task and detection_task:
+             raise ValueError(
+                 "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                 + "To get the whole dataset with boxes and labels leave both parameters to False."
+             )
 
          # File existence check
          if not os.path.exists(label_folder) or not os.path.exists(img_folder):
@@ -59,7 +66,7 @@ class IC13(AbstractDataset):
                  f"unable to locate {label_folder if not os.path.exists(label_folder) else img_folder}"
              )
 
-         self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
+         self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
          np_dtype = np.float32
 
          img_names = os.listdir(img_folder)
@@ -95,5 +102,7 @@ class IC13(AbstractDataset):
              crops = crop_bboxes_from_image(img_path=img_path, geoms=box_targets)
              for crop, label in zip(crops, labels):
                  self.data.append((crop, label))
+         elif detection_task:
+             self.data.append((img_path, box_targets))
          else:
              self.data.append((img_path, dict(boxes=box_targets, labels=labels)))
doctr/datasets/iiit5k.py CHANGED
@@ -34,6 +34,7 @@ class IIIT5K(VisionDataset):
          train: whether the subset should be the training one
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          recognition_task: whether the dataset should be used for recognition task
+         detection_task: whether the dataset should be used for detection task
          **kwargs: keyword arguments from `VisionDataset`.
      """
 
@@ -45,6 +46,7 @@ class IIIT5K(VisionDataset):
          train: bool = True,
          use_polygons: bool = False,
          recognition_task: bool = False,
+         detection_task: bool = False,
          **kwargs: Any,
      ) -> None:
          super().__init__(
@@ -55,6 +57,12 @@ class IIIT5K(VisionDataset):
              pre_transforms=convert_target_to_relative if not recognition_task else None,
              **kwargs,
          )
+         if recognition_task and detection_task:
+             raise ValueError(
+                 "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                 + "To get the whole dataset with boxes and labels leave both parameters to False."
+             )
+
          self.train = train
 
          # Load mat data
@@ -62,7 +70,7 @@ class IIIT5K(VisionDataset):
          mat_file = "trainCharBound" if self.train else "testCharBound"
          mat_data = sio.loadmat(os.path.join(tmp_root, f"{mat_file}.mat"))[mat_file][0]
 
-         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+         self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
          np_dtype = np.float32
 
          for img_path, label, box_targets in tqdm(iterable=mat_data, desc="Unpacking IIIT5K", total=len(mat_data)):
@@ -73,24 +81,26 @@ class IIIT5K(VisionDataset):
              if not os.path.exists(os.path.join(tmp_root, _raw_path)):
                  raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, _raw_path)}")
 
+             if use_polygons:
+                 # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+                 box_targets = [
+                     [
+                         [box[0], box[1]],
+                         [box[0] + box[2], box[1]],
+                         [box[0] + box[2], box[1] + box[3]],
+                         [box[0], box[1] + box[3]],
+                     ]
+                     for box in box_targets
+                 ]
+             else:
+                 # xmin, ymin, xmax, ymax
+                 box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets]
+
              if recognition_task:
                  self.data.append((_raw_path, _raw_label))
+             elif detection_task:
+                 self.data.append((_raw_path, np.asarray(box_targets, dtype=np_dtype)))
              else:
-                 if use_polygons:
-                     # (x, y) coordinates of top left, top right, bottom right, bottom left corners
-                     box_targets = [
-                         [
-                             [box[0], box[1]],
-                             [box[0] + box[2], box[1]],
-                             [box[0] + box[2], box[1] + box[3]],
-                             [box[0], box[1] + box[3]],
-                         ]
-                         for box in box_targets
-                     ]
-                 else:
-                     # xmin, ymin, xmax, ymax
-                     box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets]
-
                  # label are casted to list where each char corresponds to the character's bounding box
                  self.data.append((
                      _raw_path,
doctr/datasets/imgur5k.py CHANGED
@@ -46,6 +46,7 @@ class IMGUR5K(AbstractDataset):
          train: whether the subset should be the training one
          use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
          recognition_task: whether the dataset should be used for recognition task
+         detection_task: whether the dataset should be used for detection task
          **kwargs: keyword arguments from `AbstractDataset`.
      """
 
@@ -56,17 +57,23 @@ class IMGUR5K(AbstractDataset):
          train: bool = True,
          use_polygons: bool = False,
          recognition_task: bool = False,
+         detection_task: bool = False,
          **kwargs: Any,
      ) -> None:
          super().__init__(
              img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
          )
+         if recognition_task and detection_task:
+             raise ValueError(
+                 "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                 + "To get the whole dataset with boxes and labels leave both parameters to False."
+             )
 
          # File existence check
          if not os.path.exists(label_path) or not os.path.exists(img_folder):
              raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")
 
-         self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
+         self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
          self.train = train
          np_dtype = np.float32
 
@@ -112,7 +119,7 @@ class IMGUR5K(AbstractDataset):
                  if ann["word"] != "."
              ]
              # (x, y) coordinates of top left, top right, bottom right, bottom left corners
-             box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes] # type: ignore[arg-type]
+             box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes]
 
              if not use_polygons:
                  # xmin, ymin, xmax, ymax
@@ -132,6 +139,8 @@ class IMGUR5K(AbstractDataset):
                      tmp_img = Image.fromarray(crop)
                      tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png"))
                      reco_images_counter += 1
+             elif detection_task:
+                 self.data.append((img_path, np.asarray(box_targets, dtype=np_dtype)))
              else:
                  self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels)))
 
doctr/datasets/loader.py CHANGED
@@ -9,8 +9,6 @@ from typing import Callable, Optional
  import numpy as np
  import tensorflow as tf
 
- from doctr.utils.multithreading import multithread_exec
-
  __all__ = ["DataLoader"]
 
 
@@ -47,7 +45,6 @@ class DataLoader:
          shuffle: whether the samples should be shuffled before passing it to the iterator
          batch_size: number of elements in each batch
          drop_last: if `True`, drops the last batch if it isn't full
-         num_workers: number of workers to use for data loading
          collate_fn: function to merge samples into a batch
      """
 
@@ -57,7 +54,6 @@ class DataLoader:
          shuffle: bool = True,
          batch_size: int = 1,
          drop_last: bool = False,
-         num_workers: Optional[int] = None,
          collate_fn: Optional[Callable] = None,
      ) -> None:
          self.dataset = dataset
@@ -69,7 +65,6 @@ class DataLoader:
              self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate
          else:
              self.collate_fn = collate_fn
-         self.num_workers = num_workers
          self.reset()
 
      def __len__(self) -> int:
@@ -92,7 +87,7 @@ class DataLoader:
          idx = self._num_yielded * self.batch_size
          indices = self.indices[idx : min(len(self.dataset), idx + self.batch_size)]
 
-         samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers))
+         samples = list(map(self.dataset.__getitem__, indices))
 
          batch_data = self.collate_fn(samples)
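With `multithread_exec` removed, the TensorFlow DataLoader fetches samples sequentially through `map`, which is why `num_workers` disappears from both the signature and the docstring. A minimal sketch with a toy in-memory dataset (any object exposing `__len__` and `__getitem__` works; the class is illustrative):

import tensorflow as tf
from doctr.datasets.loader import DataLoader

class ToyDataset:  # illustrative stand-in for a doctr dataset
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        return tf.zeros((32, 32, 3), dtype=tf.float32), idx

loader = DataLoader(ToyDataset(), shuffle=True, batch_size=2, drop_last=False)
for images, targets in loader:
    print(images.shape, targets)  # (2, 32, 32, 3) and a list of indices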