python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/cord.py +10 -1
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +11 -1
  8. doctr/datasets/generator/base.py +6 -5
  9. doctr/datasets/ic03.py +11 -1
  10. doctr/datasets/ic13.py +10 -1
  11. doctr/datasets/iiit5k.py +26 -16
  12. doctr/datasets/imgur5k.py +11 -2
  13. doctr/datasets/loader.py +1 -6
  14. doctr/datasets/sroie.py +11 -1
  15. doctr/datasets/svhn.py +11 -1
  16. doctr/datasets/svt.py +11 -1
  17. doctr/datasets/synthtext.py +11 -1
  18. doctr/datasets/utils.py +9 -3
  19. doctr/datasets/vocabs.py +15 -4
  20. doctr/datasets/wildreceipt.py +12 -1
  21. doctr/file_utils.py +45 -12
  22. doctr/io/elements.py +52 -10
  23. doctr/io/html.py +2 -2
  24. doctr/io/image/pytorch.py +6 -8
  25. doctr/io/image/tensorflow.py +1 -1
  26. doctr/io/pdf.py +5 -2
  27. doctr/io/reader.py +6 -0
  28. doctr/models/__init__.py +0 -1
  29. doctr/models/_utils.py +57 -20
  30. doctr/models/builder.py +73 -15
  31. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  32. doctr/models/classification/mobilenet/pytorch.py +47 -9
  33. doctr/models/classification/mobilenet/tensorflow.py +51 -14
  34. doctr/models/classification/predictor/pytorch.py +28 -17
  35. doctr/models/classification/predictor/tensorflow.py +26 -16
  36. doctr/models/classification/resnet/tensorflow.py +21 -8
  37. doctr/models/classification/textnet/pytorch.py +3 -3
  38. doctr/models/classification/textnet/tensorflow.py +11 -5
  39. doctr/models/classification/vgg/tensorflow.py +9 -3
  40. doctr/models/classification/vit/tensorflow.py +10 -4
  41. doctr/models/classification/zoo.py +55 -19
  42. doctr/models/detection/_utils/__init__.py +1 -0
  43. doctr/models/detection/_utils/base.py +66 -0
  44. doctr/models/detection/differentiable_binarization/base.py +4 -3
  45. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  46. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  47. doctr/models/detection/fast/base.py +6 -5
  48. doctr/models/detection/fast/pytorch.py +4 -4
  49. doctr/models/detection/fast/tensorflow.py +15 -12
  50. doctr/models/detection/linknet/base.py +4 -3
  51. doctr/models/detection/linknet/tensorflow.py +23 -11
  52. doctr/models/detection/predictor/pytorch.py +15 -1
  53. doctr/models/detection/predictor/tensorflow.py +17 -3
  54. doctr/models/detection/zoo.py +7 -2
  55. doctr/models/factory/hub.py +8 -18
  56. doctr/models/kie_predictor/base.py +13 -3
  57. doctr/models/kie_predictor/pytorch.py +45 -20
  58. doctr/models/kie_predictor/tensorflow.py +44 -17
  59. doctr/models/modules/layers/pytorch.py +2 -3
  60. doctr/models/modules/layers/tensorflow.py +6 -8
  61. doctr/models/modules/transformer/pytorch.py +2 -2
  62. doctr/models/modules/transformer/tensorflow.py +0 -2
  63. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  64. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  65. doctr/models/predictor/base.py +97 -58
  66. doctr/models/predictor/pytorch.py +35 -20
  67. doctr/models/predictor/tensorflow.py +35 -18
  68. doctr/models/preprocessor/pytorch.py +4 -4
  69. doctr/models/preprocessor/tensorflow.py +3 -2
  70. doctr/models/recognition/crnn/tensorflow.py +8 -6
  71. doctr/models/recognition/master/pytorch.py +2 -2
  72. doctr/models/recognition/master/tensorflow.py +9 -4
  73. doctr/models/recognition/parseq/pytorch.py +4 -3
  74. doctr/models/recognition/parseq/tensorflow.py +14 -11
  75. doctr/models/recognition/sar/pytorch.py +7 -6
  76. doctr/models/recognition/sar/tensorflow.py +10 -12
  77. doctr/models/recognition/vitstr/pytorch.py +1 -1
  78. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  79. doctr/models/recognition/zoo.py +1 -1
  80. doctr/models/utils/pytorch.py +1 -1
  81. doctr/models/utils/tensorflow.py +15 -15
  82. doctr/models/zoo.py +2 -2
  83. doctr/py.typed +0 -0
  84. doctr/transforms/functional/base.py +1 -1
  85. doctr/transforms/functional/pytorch.py +5 -5
  86. doctr/transforms/modules/base.py +37 -15
  87. doctr/transforms/modules/pytorch.py +73 -14
  88. doctr/transforms/modules/tensorflow.py +78 -19
  89. doctr/utils/fonts.py +7 -5
  90. doctr/utils/geometry.py +141 -31
  91. doctr/utils/metrics.py +34 -175
  92. doctr/utils/reconstitution.py +212 -0
  93. doctr/utils/visualization.py +5 -118
  94. doctr/version.py +1 -1
  95. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
  96. python_doctr-0.10.0.dist-info/RECORD +173 -0
  97. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  98. doctr/models/artefacts/__init__.py +0 -2
  99. doctr/models/artefacts/barcode.py +0 -74
  100. doctr/models/artefacts/face.py +0 -63
  101. doctr/models/obj_detection/__init__.py +0 -1
  102. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  103. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  104. python_doctr-0.8.1.dist-info/RECORD +0 -173
  105. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  106. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  107. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/datasets/sroie.py CHANGED
@@ -33,6 +33,7 @@ class SROIE(VisionDataset):
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
+        detection_task: whether the dataset should be used for detection task
         **kwargs: keyword arguments from `VisionDataset`.
     """
 
@@ -52,6 +53,7 @@ class SROIE(VisionDataset):
         train: bool = True,
         use_polygons: bool = False,
         recognition_task: bool = False,
+        detection_task: bool = False,
         **kwargs: Any,
     ) -> None:
         url, sha256, name = self.TRAIN if train else self.TEST
@@ -63,10 +65,16 @@ class SROIE(VisionDataset):
             pre_transforms=convert_target_to_relative if not recognition_task else None,
             **kwargs,
         )
+        if recognition_task and detection_task:
+            raise ValueError(
+                "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                + "To get the whole dataset with boxes and labels leave both parameters to False."
+            )
+
         self.train = train
 
         tmp_root = os.path.join(self.root, "images")
-        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
         np_dtype = np.float32
 
         for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking SROIE", total=len(os.listdir(tmp_root))):
@@ -94,6 +102,8 @@ class SROIE(VisionDataset):
                 for crop, label in zip(crops, labels):
                     if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                         self.data.append((crop, label))
+            elif detection_task:
+                self.data.append((img_path, coords))
             else:
                 self.data.append((img_path, dict(boxes=coords, labels=labels)))
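The same `detection_task` flag and the same mutual-exclusion check are added to SVHN, SVT, SynthText, and WILDRECEIPT in the diffs below (and, per the file list, to CORD, FUNSD, IC03, IC13, IIIT5K and IMGUR5K as well). A minimal usage sketch of the new mode, assuming the archive can be downloaded; the target shapes are inferred from the diff:

    from doctr.datasets import SROIE

    # Default: full targets, a dict with "boxes" and "labels" per image
    ds = SROIE(train=True, download=True)
    img, target = ds[0]

    # New in 0.10.0: detection-only targets, a numpy array of boxes per image
    ds_det = SROIE(train=True, download=True, detection_task=True)
    img, boxes = ds_det[0]

    # Both flags together raise a ValueError
    SROIE(train=True, download=True, recognition_task=True, detection_task=True)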
doctr/datasets/svhn.py CHANGED
@@ -32,6 +32,7 @@ class SVHN(VisionDataset):
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
+        detection_task: whether the dataset should be used for detection task
         **kwargs: keyword arguments from `VisionDataset`.
     """
 
@@ -52,6 +53,7 @@ class SVHN(VisionDataset):
         train: bool = True,
         use_polygons: bool = False,
         recognition_task: bool = False,
+        detection_task: bool = False,
         **kwargs: Any,
     ) -> None:
         url, sha256, name = self.TRAIN if train else self.TEST
@@ -63,8 +65,14 @@ class SVHN(VisionDataset):
             pre_transforms=convert_target_to_relative if not recognition_task else None,
             **kwargs,
         )
+        if recognition_task and detection_task:
+            raise ValueError(
+                "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                + "To get the whole dataset with boxes and labels leave both parameters to False."
+            )
+
         self.train = train
-        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
         np_dtype = np.float32
 
         tmp_root = os.path.join(self.root, "train" if train else "test")
@@ -122,6 +130,8 @@ class SVHN(VisionDataset):
                 for crop, label in zip(crops, label_targets):
                     if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                         self.data.append((crop, label))
+            elif detection_task:
+                self.data.append((img_name, box_targets))
             else:
                 self.data.append((img_name, dict(boxes=box_targets, labels=label_targets)))
doctr/datasets/svt.py CHANGED
@@ -32,6 +32,7 @@ class SVT(VisionDataset):
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
+        detection_task: whether the dataset should be used for detection task
         **kwargs: keyword arguments from `VisionDataset`.
     """
 
@@ -43,6 +44,7 @@ class SVT(VisionDataset):
         train: bool = True,
         use_polygons: bool = False,
         recognition_task: bool = False,
+        detection_task: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -53,8 +55,14 @@ class SVT(VisionDataset):
             pre_transforms=convert_target_to_relative if not recognition_task else None,
             **kwargs,
         )
+        if recognition_task and detection_task:
+            raise ValueError(
+                "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                + "To get the whole dataset with boxes and labels leave both parameters to False."
+            )
+
         self.train = train
-        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
         np_dtype = np.float32
 
         # Load xml data
@@ -108,6 +116,8 @@ class SVT(VisionDataset):
                 for crop, label in zip(crops, labels):
                     if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                         self.data.append((crop, label))
+            elif detection_task:
+                self.data.append((name.text, boxes))
             else:
                 self.data.append((name.text, dict(boxes=boxes, labels=labels)))
doctr/datasets/synthtext.py CHANGED
@@ -35,6 +35,7 @@ class SynthText(VisionDataset):
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
+        detection_task: whether the dataset should be used for detection task
         **kwargs: keyword arguments from `VisionDataset`.
     """
 
@@ -46,6 +47,7 @@ class SynthText(VisionDataset):
         train: bool = True,
         use_polygons: bool = False,
         recognition_task: bool = False,
+        detection_task: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -56,8 +58,14 @@ class SynthText(VisionDataset):
             pre_transforms=convert_target_to_relative if not recognition_task else None,
             **kwargs,
         )
+        if recognition_task and detection_task:
+            raise ValueError(
+                "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                + "To get the whole dataset with boxes and labels leave both parameters to False."
+            )
+
         self.train = train
-        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
         np_dtype = np.float32
 
         # Load mat data
@@ -111,6 +119,8 @@ class SynthText(VisionDataset):
                         tmp_img = Image.fromarray(crop)
                         tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png"))
                         reco_images_counter += 1
+            elif detection_task:
+                self.data.append((img_path[0], np.asarray(word_boxes, dtype=np_dtype)))
             else:
                 self.data.append((img_path[0], dict(boxes=np.asarray(word_boxes, dtype=np_dtype), labels=labels)))
doctr/datasets/utils.py CHANGED
@@ -169,8 +169,13 @@ def encode_sequences(
     return encoded_data
 
 
-def convert_target_to_relative(img: ImageTensor, target: Dict[str, Any]) -> Tuple[ImageTensor, Dict[str, Any]]:
-    target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img))
+def convert_target_to_relative(
+    img: ImageTensor, target: Union[np.ndarray, Dict[str, Any]]
+) -> Tuple[ImageTensor, Union[Dict[str, Any], np.ndarray]]:
+    if isinstance(target, np.ndarray):
+        target = convert_to_relative_coords(target, get_img_shape(img))
+    else:
+        target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img))
     return img, target
 
 
@@ -186,7 +191,8 @@ def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> List[np.ndarray]:
     -------
         a list of cropped images
     """
-    img: np.ndarray = np.array(Image.open(img_path).convert("RGB"))
+    with Image.open(img_path) as pil_img:
+        img: np.ndarray = np.array(pil_img.convert("RGB"))
     # Polygon
     if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
         return extract_rcrops(img, geoms.astype(dtype=int))
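The widened signature lets the same pre-transform handle both the dict targets of the default mode and the bare box arrays of the new `detection_task` mode. A worked sketch, assuming the PyTorch backend so that `img` is a C x H x W tensor:

    import numpy as np
    import torch

    from doctr.datasets.utils import convert_target_to_relative

    img = torch.zeros((3, 100, 200))  # H=100, W=200
    boxes = np.array([[10, 20, 50, 40]], dtype=np.float32)  # absolute pixel coords

    # ndarray target (detection mode): the array itself is rescaled
    _, rel = convert_target_to_relative(img, boxes.copy())
    # rel -> [[0.05, 0.2, 0.25, 0.4]]

    # dict target (default mode): only target["boxes"] is rescaled
    _, rel_dict = convert_target_to_relative(img, {"boxes": boxes.copy(), "labels": ["word"]})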
doctr/datasets/vocabs.py CHANGED
@@ -17,9 +17,15 @@ VOCABS: Dict[str, str] = {
     "ancient_greek": "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ",
     "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي",
     "persian_letters": "پچڢڤگ",
-    "hindi_digits": "٠١٢٣٤٥٦٧٨٩",
+    "arabic_digits": "٠١٢٣٤٥٦٧٨٩",
     "arabic_diacritics": "ًٌٍَُِّْ",
     "arabic_punctuation": "؟؛«»—",
+    "hindi_letters": "अआइईउऊऋॠऌॡएऐओऔअंअःकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह",
+    "hindi_digits": "०१२३४५६७८९",
+    "hindi_punctuation": "।,?!:्ॐ॰॥॰",
+    "bangla_letters": "অআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃেৈোৌ্ৎংঃঁ",
+    "bangla_digits": "০১২৩৪৫৬৭৮৯",
+    "generic_cyrillic_letters": "абвгдежзийклмнопрстуфхцчшщьюяАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЬЮЯ",
 }
 
 VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"]
@@ -32,7 +38,7 @@ VOCABS["italian"] = VOCABS["english"] + "àèéìíîòóùúÀÈÉÌÍÎÒÓÙÚ"
 VOCABS["german"] = VOCABS["english"] + "äöüßÄÖÜẞ"
 VOCABS["arabic"] = (
     VOCABS["digits"]
-    + VOCABS["hindi_digits"]
+    + VOCABS["arabic_digits"]
     + VOCABS["arabic_letters"]
     + VOCABS["persian_letters"]
     + VOCABS["arabic_diacritics"]
@@ -48,10 +54,15 @@ VOCABS["finnish"] = VOCABS["english"] + "äöÄÖ"
 VOCABS["swedish"] = VOCABS["english"] + "åäöÅÄÖ"
 VOCABS["vietnamese"] = (
     VOCABS["english"]
-    + "áàảạãăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ"
-    + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ"
+    + "áàảạãăắằẳẵặâấầẩẫậđéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ"
+    + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬĐÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ"
 )
 VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪"
+VOCABS["hindi"] = VOCABS["hindi_letters"] + VOCABS["hindi_digits"] + VOCABS["hindi_punctuation"]
+VOCABS["bangla"] = VOCABS["bangla_letters"] + VOCABS["bangla_digits"]
+VOCABS["ukrainian"] = (
+    VOCABS["generic_cyrillic_letters"] + VOCABS["digits"] + VOCABS["punctuation"] + VOCABS["currency"] + "ґіїєҐІЇЄ₴"
+)
 VOCABS["multilingual"] = "".join(
     dict.fromkeys(
         VOCABS["french"]
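Note the semantic shift here: the Eastern Arabic numerals formerly stored under `hindi_digits` move to the new `arabic_digits` key, and `hindi_digits` now holds Devanagari digits, which keeps `VOCABS["arabic"]` unchanged while making room for the new Hindi vocab. A quick sanity check of the composition, assuming `VOCABS` is imported from `doctr.datasets`:

    from doctr.datasets import VOCABS

    assert "٠" in VOCABS["arabic_digits"]  # Eastern Arabic zero, renamed key
    assert "०" in VOCABS["hindi_digits"]   # Devanagari zero, new meaning
    assert "٠" in VOCABS["arabic"]         # composed vocab is unaffected by the rename

    # Derived vocabs are plain string concatenation
    assert VOCABS["bangla"] == VOCABS["bangla_letters"] + VOCABS["bangla_digits"]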
doctr/datasets/wildreceipt.py CHANGED
@@ -40,6 +40,7 @@ class WILDRECEIPT(AbstractDataset):
         train: whether the subset should be the training one
         use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
         recognition_task: whether the dataset should be used for recognition task
+        detection_task: whether the dataset should be used for detection task
         **kwargs: keyword arguments from `AbstractDataset`.
     """
 
@@ -50,11 +51,19 @@ class WILDRECEIPT(AbstractDataset):
         train: bool = True,
         use_polygons: bool = False,
         recognition_task: bool = False,
+        detection_task: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(
             img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
         )
+        # Task check
+        if recognition_task and detection_task:
+            raise ValueError(
+                "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+                + "To get the whole dataset with boxes and labels leave both parameters to False."
+            )
+
         # File existence check
         if not os.path.exists(label_path) or not os.path.exists(img_folder):
             raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")
@@ -62,7 +71,7 @@ class WILDRECEIPT(AbstractDataset):
         tmp_root = img_folder
         self.train = train
         np_dtype = np.float32
-        self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
+        self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = []
 
         with open(label_path, "r") as file:
             data = file.read()
@@ -100,6 +109,8 @@ class WILDRECEIPT(AbstractDataset):
                 for crop, label in zip(crops, list(text_targets)):
                     if label and " " not in label:
                         self.data.append((crop, label))
+            elif detection_task:
+                self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0)))
             else:
                 self.data.append((
                     img_path,
doctr/file_utils.py CHANGED
@@ -5,21 +5,16 @@
 
 # Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
 
+import importlib.metadata
 import importlib.util
 import logging
 import os
-import sys
+from typing import Optional
 
 CLASS_NAME: str = "words"
 
 
-if sys.version_info < (3, 8):  # pragma: no cover
-    import importlib_metadata
-else:
-    import importlib.metadata as importlib_metadata
-
-
-__all__ = ["is_tf_available", "is_torch_available", "CLASS_NAME"]
+__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"]
 
 ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
 ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
@@ -32,14 +27,28 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
     _torch_available = importlib.util.find_spec("torch") is not None
     if _torch_available:
         try:
-            _torch_version = importlib_metadata.version("torch")
+            _torch_version = importlib.metadata.version("torch")
             logging.info(f"PyTorch version {_torch_version} available.")
-        except importlib_metadata.PackageNotFoundError:  # pragma: no cover
+        except importlib.metadata.PackageNotFoundError:  # pragma: no cover
             _torch_available = False
 else:  # pragma: no cover
     logging.info("Disabling PyTorch because USE_TF is set")
     _torch_available = False
 
+# Compatibility fix to make sure tensorflow.keras stays at Keras 2
+if "TF_USE_LEGACY_KERAS" not in os.environ:
+    os.environ["TF_USE_LEGACY_KERAS"] = "1"
+
+elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
+    raise ValueError(
+        "docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
+    )
+
+
+def ensure_keras_v2() -> None:  # pragma: no cover
+    if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
+        os.environ["TF_USE_LEGACY_KERAS"] = "1"
+
 
 if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
     _tf_available = importlib.util.find_spec("tensorflow") is not None
@@ -59,9 +68,9 @@ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
     # For the metadata, we have to look for both tensorflow and tensorflow-cpu
     for pkg in candidates:
         try:
-            _tf_version = importlib_metadata.version(pkg)
+            _tf_version = importlib.metadata.version(pkg)
             break
-        except importlib_metadata.PackageNotFoundError:
+        except importlib.metadata.PackageNotFoundError:
             pass
     _tf_available = _tf_version is not None
     if _tf_available:
@@ -70,6 +79,11 @@ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
             _tf_available = False
         else:
             logging.info(f"TensorFlow version {_tf_version} available.")
+            ensure_keras_v2()
+            import tensorflow as tf
+
+            # Enable eager execution - this is required for some models to work properly
+            tf.config.run_functions_eagerly(True)
 else:  # pragma: no cover
     logging.info("Disabling Tensorflow because USE_TORCH is set")
     _tf_available = False
@@ -82,6 +96,25 @@ if not _torch_available and not _tf_available:  # pragma: no cover
     )
 
 
+def requires_package(name: str, extra_message: Optional[str] = None) -> None:  # pragma: no cover
+    """
+    package requirement helper
+
+    Args:
+    ----
+        name: name of the package
+        extra_message: additional message to display if the package is not found
+    """
+    try:
+        _pkg_version = importlib.metadata.version(name)
+        logging.info(f"{name} version {_pkg_version} available.")
+    except importlib.metadata.PackageNotFoundError:
+        raise ImportError(
+            f"\n\n{extra_message if extra_message is not None else ''} "
+            f"\nPlease install it with the following command: pip install {name}\n"
+        )
+
+
 def is_torch_available():
     """Whether PyTorch is installed."""
     return _torch_available
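Two things happen in this file: TensorFlow installs are pinned to Keras 2 via `TF_USE_LEGACY_KERAS` (with eager execution forced on), and the new `requires_package` helper converts a missing optional dependency into an `ImportError` with install instructions at call time rather than an import-time crash. A sketch of the intended call-site pattern; `show_something` is a hypothetical caller (the real ones appear in the elements.py diff below):

    from doctr.file_utils import requires_package

    def show_something():
        # raises ImportError with a pip hint if matplotlib is not installed
        requires_package("matplotlib", "`show_something()` requires matplotlib installed")
        import matplotlib.pyplot as plt  # safe to import past this point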
doctr/io/elements.py CHANGED
@@ -12,14 +12,19 @@ from xml.etree import ElementTree as ET
 from xml.etree.ElementTree import Element as ETElement
 from xml.etree.ElementTree import SubElement
 
-import matplotlib.pyplot as plt
 import numpy as np
 
 import doctr
+from doctr.file_utils import requires_package
 from doctr.utils.common_types import BoundingBox
 from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox
+from doctr.utils.reconstitution import synthesize_kie_page, synthesize_page
 from doctr.utils.repr import NestedObject
-from doctr.utils.visualization import synthesize_kie_page, synthesize_page, visualize_kie_page, visualize_page
+
+try:  # optional dependency for visualization
+    from doctr.utils.visualization import visualize_kie_page, visualize_page
+except ModuleNotFoundError:
+    pass
 
 __all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"]
 
@@ -67,16 +72,27 @@ class Word(Element):
         confidence: the confidence associated with the text prediction
         geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
             the page's size
+        objectness_score: the objectness score of the detection
+        crop_orientation: the general orientation of the crop in degrees and its confidence
     """
 
-    _exported_keys: List[str] = ["value", "confidence", "geometry"]
+    _exported_keys: List[str] = ["value", "confidence", "geometry", "objectness_score", "crop_orientation"]
     _children_names: List[str] = []
 
-    def __init__(self, value: str, confidence: float, geometry: Union[BoundingBox, np.ndarray]) -> None:
+    def __init__(
+        self,
+        value: str,
+        confidence: float,
+        geometry: Union[BoundingBox, np.ndarray],
+        objectness_score: float,
+        crop_orientation: Dict[str, Any],
+    ) -> None:
         super().__init__()
         self.value = value
         self.confidence = confidence
         self.geometry = geometry
+        self.objectness_score = objectness_score
+        self.crop_orientation = crop_orientation
 
     def render(self) -> str:
         """Renders the full text of the element"""
@@ -135,7 +151,7 @@ class Line(Element):
         all words in it.
     """
 
-    _exported_keys: List[str] = ["geometry"]
+    _exported_keys: List[str] = ["geometry", "objectness_score"]
     _children_names: List[str] = ["words"]
     words: List[Word] = []
 
@@ -143,15 +159,20 @@ class Line(Element):
         self,
         words: List[Word],
         geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
+        objectness_score: Optional[float] = None,
     ) -> None:
+        # Compute the objectness score of the line
+        if objectness_score is None:
+            objectness_score = float(np.mean([w.objectness_score for w in words]))
         # Resolve the geometry using the smallest enclosing bounding box
         if geometry is None:
             # Check whether this is a rotated or straight box
             box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 4 else resolve_enclosing_bbox
-            geometry = box_resolution_fn([w.geometry for w in words])  # type: ignore[operator]
+            geometry = box_resolution_fn([w.geometry for w in words])  # type: ignore[misc]
 
         super().__init__(words=words)
         self.geometry = geometry
+        self.objectness_score = objectness_score
 
     def render(self) -> str:
         """Renders the full text of the element"""
@@ -189,7 +210,7 @@ class Block(Element):
         all lines and artefacts in it.
     """
 
-    _exported_keys: List[str] = ["geometry"]
+    _exported_keys: List[str] = ["geometry", "objectness_score"]
     _children_names: List[str] = ["lines", "artefacts"]
     lines: List[Line] = []
     artefacts: List[Artefact] = []
@@ -199,7 +220,11 @@ class Block(Element):
         lines: List[Line] = [],
         artefacts: List[Artefact] = [],
         geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
+        objectness_score: Optional[float] = None,
     ) -> None:
+        # Compute the objectness score of the line
+        if objectness_score is None:
+            objectness_score = float(np.mean([w.objectness_score for line in lines for w in line.words]))
         # Resolve the geometry using the smallest enclosing bounding box
         if geometry is None:
             line_boxes = [word.geometry for line in lines for word in line.words]
@@ -207,10 +232,11 @@ class Block(Element):
             box_resolution_fn = (
                 resolve_enclosing_rbbox if isinstance(lines[0].geometry, np.ndarray) else resolve_enclosing_bbox
             )
-            geometry = box_resolution_fn(line_boxes + artefact_boxes)  # type: ignore[operator]
+            geometry = box_resolution_fn(line_boxes + artefact_boxes)  # type: ignore
 
         super().__init__(lines=lines, artefacts=artefacts)
         self.geometry = geometry
+        self.objectness_score = objectness_score
 
     def render(self, line_break: str = "\n") -> str:
         """Renders the full text of the element"""
@@ -274,12 +300,20 @@ class Page(Element):
             preserve_aspect_ratio: pass True if you passed True to the predictor
             **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method
         """
+        requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed")
+        requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed")
+        import matplotlib.pyplot as plt
+
         visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
         plt.show(**kwargs)
 
     def synthesize(self, **kwargs) -> np.ndarray:
         """Synthesize the page from the predictions
 
+        Args:
+        ----
+            **kwargs: keyword arguments passed to the `synthesize_page` method
+
         Returns
         -------
             synthesized page
@@ -449,6 +483,10 @@ class KIEPage(Element):
             preserve_aspect_ratio: pass True if you passed True to the predictor
             **kwargs: keyword arguments passed to the matplotlib.pyplot.show method
         """
+        requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed")
+        requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed")
+        import matplotlib.pyplot as plt
+
         visualize_kie_page(
             self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio
         )
@@ -459,7 +497,7 @@ class KIEPage(Element):
 
         Args:
         ----
-            **kwargs: keyword arguments passed to the matplotlib.pyplot.show method
+            **kwargs: keyword arguments passed to the `synthesize_kie_page` method
 
         Returns:
         -------
@@ -569,11 +607,15 @@
     def synthesize(self, **kwargs) -> List[np.ndarray]:
         """Synthesize all pages from their predictions
 
+        Args:
+        ----
+            **kwargs: keyword arguments passed to the `Page.synthesize` method
+
         Returns
         -------
             list of synthesized pages
         """
-        return [page.synthesize() for page in self.pages]
+        return [page.synthesize(**kwargs) for page in self.pages]
 
     def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]:
         """Export the document as XML (hOCR-format)
doctr/io/html.py CHANGED
@@ -5,8 +5,6 @@
 
 from typing import Any
 
-from weasyprint import HTML
-
 __all__ = ["read_html"]
 
 
@@ -25,4 +23,6 @@ def read_html(url: str, **kwargs: Any) -> bytes:
     -------
         decoded PDF file as a bytes stream
     """
+    from weasyprint import HTML
+
     return HTML(url, **kwargs).write_pdf()
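Deferring the weasyprint import into the function body makes it a lazy, optional dependency, so `import doctr` no longer fails when weasyprint is absent. Usage is unchanged, assuming weasyprint is installed:

    from doctr.io.html import read_html

    pdf_bytes = read_html("https://www.example.com")  # bytes of the rendered PDF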
doctr/io/image/pytorch.py CHANGED
@@ -16,7 +16,7 @@ from doctr.utils.common_types import AbstractPath
 __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
 
 
-def tensor_from_pil(pil_img: Image, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> torch.Tensor:
     """Convert a PIL Image to a PyTorch tensor
 
     Args:
@@ -51,9 +51,8 @@ def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float32) -> torch.Tensor:
     if dtype not in (torch.uint8, torch.float16, torch.float32):
         raise ValueError("insupported value for dtype")
 
-    pil_img = Image.open(img_path, mode="r").convert("RGB")
-
-    return tensor_from_pil(pil_img, dtype)
+    with Image.open(img_path, mode="r") as pil_img:
+        return tensor_from_pil(pil_img.convert("RGB"), dtype)
 
 
 def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor:
@@ -71,9 +70,8 @@ def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor:
     if dtype not in (torch.uint8, torch.float16, torch.float32):
         raise ValueError("insupported value for dtype")
 
-    pil_img = Image.open(BytesIO(img_content), mode="r").convert("RGB")
-
-    return tensor_from_pil(pil_img, dtype)
+    with Image.open(BytesIO(img_content), mode="r") as pil_img:
+        return tensor_from_pil(pil_img.convert("RGB"), dtype)
 
 
 def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor:
@@ -106,4 +104,4 @@ def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor:
 
 def get_img_shape(img: torch.Tensor) -> Tuple[int, int]:
     """Get the shape of an image"""
-    return img.shape[-2:]
+    return img.shape[-2:]  # type: ignore[return-value]
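The move to `with Image.open(...)` closes the underlying file handle deterministically instead of relying on garbage collection; the returned tensor is unchanged. A sketch assuming the PyTorch backend and a placeholder image path:

    import torch
    from doctr.io import read_img_as_tensor

    img = read_img_as_tensor("sample.jpg", dtype=torch.float32)  # 3 x H x W, values in [0, 1]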
doctr/io/image/tensorflow.py CHANGED
@@ -15,7 +15,7 @@ from doctr.utils.common_types import AbstractPath
 __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
 
 
-def tensor_from_pil(pil_img: Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
+def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
     """Convert a PIL Image to a TensorFlow tensor
 
     Args:
doctr/io/pdf.py CHANGED
@@ -38,5 +38,8 @@ def read_pdf(
         the list of pages decoded as numpy ndarray of shape H x W x C
     """
     # Rasterise pages to numpy ndarrays with pypdfium2
-    pdf = pdfium.PdfDocument(file, password=password, autoclose=True)
-    return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf]
+    pdf = pdfium.PdfDocument(file, password=password)
+    try:
+        return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf]
+    finally:
+        pdf.close()
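Replacing `autoclose=True` with an explicit try/finally guarantees the document handle is released even if rendering raises midway. The public behaviour is unchanged; a sketch with a placeholder path:

    from doctr.io.pdf import read_pdf

    pages = read_pdf("sample.pdf", scale=2)  # list of H x W x C uint8 arrays
    print(len(pages), pages[0].shape)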
doctr/io/reader.py CHANGED
@@ -8,6 +8,7 @@ from typing import List, Sequence, Union
 
 import numpy as np
 
+from doctr.file_utils import requires_package
 from doctr.utils.common_types import AbstractFile
 
 from .html import read_html
@@ -54,6 +55,11 @@ class DocumentFile:
         -------
             the list of pages decoded as numpy ndarray of shape H x W x 3
         """
+        requires_package(
+            "weasyprint",
+            "`.from_url` requires weasyprint installed.\n"
+            + "Installation instructions: https://doc.courtbouillon.org/weasyprint/stable/first_steps.html#installation",
+        )
         pdf_stream = read_html(url)
         return cls.from_pdf(pdf_stream, **kwargs)
 
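With weasyprint optional, `DocumentFile.from_url` now checks for it lazily and points to the installation docs instead of breaking at import time. A sketch, assuming weasyprint is installed:

    from doctr.io import DocumentFile

    pages = DocumentFile.from_url("https://www.example.com")  # rendered to PDF, then rasterised
    print(len(pages), pages[0].shape)  # numpy arrays of shape H x W x 3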
doctr/models/__init__.py CHANGED
@@ -1,4 +1,3 @@
-from . import artefacts
 from .classification import *
 from .detection import *
 from .recognition import *