python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/datasets/sroie.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
  import csv
  import os
  from pathlib import Path
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import numpy as np
  from tqdm import tqdm
@@ -29,7 +29,6 @@ class SROIE(VisionDataset):
  >>> img, target = train_set[0]

  Args:
- ----
  train: whether the subset should be the training one
  use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
  recognition_task: whether the dataset should be used for recognition task
@@ -74,10 +73,12 @@ class SROIE(VisionDataset):
  self.train = train

  tmp_root = os.path.join(self.root, "images")
- self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
+ self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
  np_dtype = np.float32

- for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking SROIE", total=len(os.listdir(tmp_root))):
+ for img_path in tqdm(
+     iterable=os.listdir(tmp_root), desc="Preparing and Loading SROIE", total=len(os.listdir(tmp_root))
+ ):
  # File existence check
  if not os.path.exists(os.path.join(tmp_root, img_path)):
  raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
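The typing change above recurs in nearly every file of this diff: 0.12.0 drops the `typing.List`/`Tuple`/`Dict`/`Union`/`Optional` aliases in favor of builtin generics (PEP 585) and `|` unions (PEP 604). A minimal before/after sketch (variable names are illustrative; the `|` annotation syntax needs Python >= 3.10):

    import numpy as np
    from typing import Any, Dict, List, Tuple, Union

    # 0.10.0 style: typing-module generics
    data_old: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []

    # 0.12.0 style: builtin generics (PEP 585) with union syntax (PEP 604)
    data_new: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []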
doctr/datasets/svhn.py CHANGED
@@ -1,10 +1,10 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import h5py
  import numpy as np
@@ -28,7 +28,6 @@ class SVHN(VisionDataset):
  >>> img, target = train_set[0]

  Args:
- ----
  train: whether the subset should be the training one
  use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
  recognition_task: whether the dataset should be used for recognition task
@@ -72,7 +71,7 @@
  )

  self.train = train
- self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
+ self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
  np_dtype = np.float32

  tmp_root = os.path.join(self.root, "train" if train else "test")
@@ -81,7 +80,9 @@
  with h5py.File(os.path.join(tmp_root, "digitStruct.mat"), "r") as f:
  img_refs = f["digitStruct/name"]
  box_refs = f["digitStruct/bbox"]
- for img_ref, box_ref in tqdm(iterable=zip(img_refs, box_refs), desc="Unpacking SVHN", total=len(img_refs)):
+ for img_ref, box_ref in tqdm(
+     iterable=zip(img_refs, box_refs), desc="Preparing and Loading SVHN", total=len(img_refs)
+ ):
  # convert ascii matrix to string
  img_name = "".join(map(chr, f[img_ref[0]][()].flatten()))

@@ -128,7 +129,7 @@
  if recognition_task:
  crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_name), geoms=box_targets)
  for crop, label in zip(crops, label_targets):
- if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0 and " " not in label:
  self.data.append((crop, label))
  elif detection_task:
  self.data.append((img_name, box_targets))
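The tightened condition above now also rejects crops whose label contains a space, so recognition samples are guaranteed to be single, whitespace-free words. A self-contained sketch of the predicate (sample data is made up for illustration):

    import numpy as np

    crop = np.zeros((32, 128, 3), dtype=np.uint8)  # stand-in for a cropped word image
    samples = [(crop, "123"), (crop, "12 3"), (crop, "")]
    kept = [
        (c, label)
        for c, label in samples
        if c.shape[0] > 0 and c.shape[1] > 0 and len(label) > 0 and " " not in label
    ]
    assert [label for _, label in kept] == ["123"]  # empty and spaced labels dropped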
doctr/datasets/svt.py CHANGED
@@ -1,10 +1,10 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import os
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import defusedxml.ElementTree as ET
  import numpy as np
@@ -28,7 +28,6 @@ class SVT(VisionDataset):
  >>> img, target = train_set[0]

  Args:
- ----
  train: whether the subset should be the training one
  use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
  recognition_task: whether the dataset should be used for recognition task
@@ -36,7 +35,7 @@ class SVT(VisionDataset):
  **kwargs: keyword arguments from `VisionDataset`.
  """

- URL = "http://vision.ucsd.edu/~kai/svt/svt.zip"
+ URL = "http://www.iapr-tc11.org/dataset/SVT/svt.zip"
  SHA256 = "63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf"

  def __init__(
@@ -62,7 +61,7 @@
  )

  self.train = train
- self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
+ self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
  np_dtype = np.float32

  # Load xml data
@@ -74,7 +73,7 @@
  )
  xml_root = xml_tree.getroot()

- for image in tqdm(iterable=xml_root, desc="Unpacking SVT", total=len(xml_root)):
+ for image in tqdm(iterable=xml_root, desc="Preparing and Loading SVT", total=len(xml_root)):
  name, _, _, _resolution, rectangles = image

  # File existence check
@@ -114,7 +113,7 @@
  if recognition_task:
  crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
  for crop, label in zip(crops, labels):
- if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0 and " " not in label:
  self.data.append((crop, label))
  elif detection_task:
  self.data.append((name.text, boxes))
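Besides the same label filter as SVHN, the notable change here is the download URL, which now points to the iapr-tc11.org mirror (presumably because the old vision.ucsd.edu host is no longer reachable). Caller-side usage is unchanged, as in the class docstring:

    from doctr.datasets import SVT

    train_set = SVT(train=True, download=True)  # now fetched from iapr-tc11.org
    img, target = train_set[0]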
doctr/datasets/synthtext.py CHANGED
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import glob
  import os
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import numpy as np
  from PIL import Image
@@ -31,7 +31,6 @@ class SynthText(VisionDataset):
  >>> img, target = train_set[0]

  Args:
- ----
  train: whether the subset should be the training one
  use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
  recognition_task: whether the dataset should be used for recognition task
@@ -42,6 +41,12 @@ class SynthText(VisionDataset):
  URL = "https://thor.robots.ox.ac.uk/~vgg/data/scenetext/SynthText.zip"
  SHA256 = "28ab030485ec8df3ed612c568dd71fb2793b9afbfa3a9d9c6e792aef33265bf1"

+ # filter corrupted or missing images
+ BLACKLIST = (
+     "67/fruits_129_",
+     "194/window_19_",
+ )
+
  def __init__(
  self,
  train: bool = True,
@@ -65,7 +70,7 @@
  )

  self.train = train
- self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
+ self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
  np_dtype = np.float32

  # Load mat data
@@ -91,7 +96,7 @@
  del mat_data

  for img_path, word_boxes, txt in tqdm(
- iterable=zip(paths, boxes, labels), desc="Unpacking SynthText", total=len(paths)
+ iterable=zip(paths, boxes, labels), desc="Preparing and Loading SynthText", total=len(paths)
  ):
  # File existence check
  if not os.path.exists(os.path.join(tmp_root, img_path[0])):
@@ -112,7 +117,13 @@
  if recognition_task:
  crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path[0]), geoms=word_boxes)
  for crop, label in zip(crops, labels):
- if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+ if (
+     crop.shape[0] > 0
+     and crop.shape[1] > 0
+     and len(label) > 0
+     and len(label) < 30
+     and " " not in label
+ ):
  # write data to disk
  with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
  f.write(label)
@@ -133,6 +144,7 @@
  return f"train={self.train}"

  def _read_from_folder(self, path: str) -> None:
- for img_path in glob.glob(os.path.join(path, "*.png")):
+ img_paths = glob.glob(os.path.join(path, "*.png"))
+ for img_path in tqdm(iterable=img_paths, desc="Preparing and Loading SynthText", total=len(img_paths)):
  with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
  self.data.append((img_path, f.read()))
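The hunk above only introduces the BLACKLIST constant; the code that consumes it is outside this excerpt. A plausible reading, offered as an assumption (prefix matching against the relative image path, with made-up sample paths):

    # Assumed usage, not shown in this diff: drop samples whose relative
    # path starts with a blacklisted prefix.
    BLACKLIST = ("67/fruits_129_", "194/window_19_")

    paths = ["67/fruits_129_73.jpg", "8/ballet_106_0.jpg"]  # hypothetical samples
    kept = [p for p in paths if not any(p.startswith(prefix) for prefix in BLACKLIST)]
    assert kept == ["8/ballet_106_0.jpg"]  # blacklisted prefix filtered out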
doctr/datasets/utils.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,10 +6,10 @@
  import string
  import unicodedata
  from collections.abc import Sequence
+ from collections.abc import Sequence as SequenceType
  from functools import partial
  from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
- from typing import Sequence as SequenceType
+ from typing import Any, TypeVar

  import numpy as np
  from PIL import Image
@@ -19,7 +19,15 @@ from doctr.utils.geometry import convert_to_relative_coords, extract_crops, extr
  from .vocabs import VOCABS

- __all__ = ["translate", "encode_string", "decode_sequence", "encode_sequences", "pre_transform_multiclass"]
+ __all__ = [
+     "translate",
+     "encode_string",
+     "decode_sequence",
+     "encode_sequences",
+     "pre_transform_multiclass",
+     "crop_bboxes_from_image",
+     "convert_target_to_relative",
+ ]

  ImageTensor = TypeVar("ImageTensor")

@@ -32,17 +40,15 @@ def translate(
  """Translate a string input in a given vocabulary

  Args:
- ----
  input_string: input string to translate
  vocab_name: vocabulary to use (french, latin, ...)
  unknown_char: unknown character for non-translatable characters

  Returns:
- -------
  A string translated in a given vocab
  """
  if VOCABS.get(vocab_name) is None:
- raise KeyError("output vocabulary must be in vocabs dictionnary")
+ raise KeyError("output vocabulary must be in vocabs dictionary")

  translated = ""
  for char in input_string:
@@ -63,40 +69,37 @@
  def encode_string(
  input_string: str,
  vocab: str,
- ) -> List[int]:
+ ) -> list[int]:
  """Given a predefined mapping, encode the string to a sequence of numbers

  Args:
- ----
  input_string: string to encode
  vocab: vocabulary (string), the encoding is given by the indexing of the character sequence

  Returns:
- -------
  A list encoding the input_string
  """
  try:
  return list(map(vocab.index, input_string))
- except ValueError:
+ except ValueError as e:
+     missing_chars = [char for char in input_string if char not in vocab]
  raise ValueError(
- f"some characters cannot be found in 'vocab'. \
- Please check the input string {input_string} and the vocabulary {vocab}"
- )
+     f"Some characters cannot be found in 'vocab': {set(missing_chars)}.\n"
+     f"Please check the input string `{input_string}` and the vocabulary `{vocab}`"
+ ) from e


  def decode_sequence(
- input_seq: Union[np.ndarray, SequenceType[int]],
+ input_seq: np.ndarray | SequenceType[int],
  mapping: str,
  ) -> str:
  """Given a predefined mapping, decode the sequence of numbers to a string

  Args:
- ----
  input_seq: array to decode
  mapping: vocabulary (string), the encoding is given by the indexing of the character sequence

  Returns:
- -------
  A string, decoded from input_seq
  """
  if not isinstance(input_seq, (Sequence, np.ndarray)):
@@ -108,18 +111,17 @@


  def encode_sequences(
- sequences: List[str],
+ sequences: list[str],
  vocab: str,
- target_size: Optional[int] = None,
+ target_size: int | None = None,
  eos: int = -1,
- sos: Optional[int] = None,
- pad: Optional[int] = None,
+ sos: int | None = None,
+ pad: int | None = None,
  dynamic_seq_length: bool = False,
  ) -> np.ndarray:
  """Encode character sequences using a given vocab as mapping

  Args:
- ----
  sequences: the list of character sequences of size N
  vocab: the ordered vocab to use for encoding
  target_size: maximum length of the encoded data
@@ -129,7 +131,6 @@
  dynamic_seq_length: if `target_size` is specified, uses it as upper bound and enables dynamic sequence size

  Returns:
- -------
  the padded encoded data as a tensor
  """
  if 0 <= eos < len(vocab):
@@ -170,29 +171,36 @@


  def convert_target_to_relative(
- img: ImageTensor, target: Union[np.ndarray, Dict[str, Any]]
- ) -> Tuple[ImageTensor, Union[Dict[str, Any], np.ndarray]]:
+ img: ImageTensor, target: np.ndarray | dict[str, Any]
+ ) -> tuple[ImageTensor, dict[str, Any] | np.ndarray]:
+ """Converts target to relative coordinates
+
+ Args:
+     img: tf.Tensor or torch.Tensor representing the image
+     target: target to convert to relative coordinates (boxes (N, 4) or polygons (N, 4, 2))
+
+ Returns:
+     The image and the target in relative coordinates
+ """
  if isinstance(target, np.ndarray):
- target = convert_to_relative_coords(target, get_img_shape(img))
+ target = convert_to_relative_coords(target, get_img_shape(img))  # type: ignore[arg-type]
  else:
- target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img))
+ target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img))  # type: ignore[arg-type]
  return img, target


- def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> List[np.ndarray]:
+ def crop_bboxes_from_image(img_path: str | Path, geoms: np.ndarray) -> list[np.ndarray]:
  """Crop a set of bounding boxes from an image

  Args:
- ----
  img_path: path to the image
  geoms: a array of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4)

  Returns:
- -------
  a list of cropped images
  """
  with Image.open(img_path) as pil_img:
- img: np.ndarray = np.array(pil_img.convert("RGB"))
+ img: np.ndarray = np.asarray(pil_img.convert("RGB"))
  # Polygon
  if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
  return extract_rcrops(img, geoms.astype(dtype=int))
@@ -201,21 +209,19 @@ def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> Lis
  raise ValueError("Invalid geometry format")


- def pre_transform_multiclass(img, target: Tuple[np.ndarray, List]) -> Tuple[np.ndarray, Dict[str, List]]:
+ def pre_transform_multiclass(img, target: tuple[np.ndarray, list]) -> tuple[np.ndarray, dict[str, list]]:
  """Converts multiclass target to relative coordinates.

  Args:
- ----
  img: Image
  target: tuple of target polygons and their classes names

  Returns:
- -------
  Image and dictionary of boxes, with class names as keys
  """
  boxes = convert_to_relative_coords(target[0], get_img_shape(img))
  boxes_classes = target[1]
- boxes_dict: Dict = {k: [] for k in sorted(set(boxes_classes))}
+ boxes_dict: dict = {k: [] for k in sorted(set(boxes_classes))}
  for k, poly in zip(boxes_classes, boxes):
  boxes_dict[k].append(poly)
  boxes_dict = {k: np.stack(v, axis=0) for k, v in boxes_dict.items()}
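The reworked `encode_string` failure path is easy to exercise; the new message lists the offending characters instead of only echoing the inputs (expected output reconstructed from the f-strings above):

    from doctr.datasets.utils import encode_string

    encode_string("abc!", "abc")
    # Raises:
    # ValueError: Some characters cannot be found in 'vocab': {'!'}.
    # Please check the input string `abc!` and the vocabulary `abc`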