python-doctr 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/datasets/pytorch.py +2 -2
  6. doctr/datasets/generator/base.py +6 -5
  7. doctr/datasets/imgur5k.py +1 -1
  8. doctr/datasets/loader.py +1 -6
  9. doctr/datasets/utils.py +2 -1
  10. doctr/datasets/vocabs.py +9 -2
  11. doctr/file_utils.py +26 -12
  12. doctr/io/elements.py +40 -6
  13. doctr/io/html.py +2 -2
  14. doctr/io/image/pytorch.py +6 -8
  15. doctr/io/image/tensorflow.py +1 -1
  16. doctr/io/pdf.py +5 -2
  17. doctr/io/reader.py +6 -0
  18. doctr/models/__init__.py +0 -1
  19. doctr/models/_utils.py +57 -20
  20. doctr/models/builder.py +71 -13
  21. doctr/models/classification/mobilenet/pytorch.py +45 -9
  22. doctr/models/classification/mobilenet/tensorflow.py +38 -7
  23. doctr/models/classification/predictor/pytorch.py +18 -11
  24. doctr/models/classification/predictor/tensorflow.py +16 -10
  25. doctr/models/classification/textnet/pytorch.py +3 -3
  26. doctr/models/classification/textnet/tensorflow.py +3 -3
  27. doctr/models/classification/zoo.py +39 -15
  28. doctr/models/detection/__init__.py +1 -0
  29. doctr/models/detection/_utils/__init__.py +1 -0
  30. doctr/models/detection/_utils/base.py +66 -0
  31. doctr/models/detection/differentiable_binarization/base.py +4 -3
  32. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  33. doctr/models/detection/differentiable_binarization/tensorflow.py +14 -18
  34. doctr/models/detection/fast/__init__.py +6 -0
  35. doctr/models/detection/fast/base.py +257 -0
  36. doctr/models/detection/fast/pytorch.py +442 -0
  37. doctr/models/detection/fast/tensorflow.py +428 -0
  38. doctr/models/detection/linknet/base.py +4 -3
  39. doctr/models/detection/predictor/pytorch.py +15 -1
  40. doctr/models/detection/predictor/tensorflow.py +15 -1
  41. doctr/models/detection/zoo.py +21 -4
  42. doctr/models/factory/hub.py +3 -12
  43. doctr/models/kie_predictor/base.py +9 -3
  44. doctr/models/kie_predictor/pytorch.py +41 -20
  45. doctr/models/kie_predictor/tensorflow.py +36 -16
  46. doctr/models/modules/layers/pytorch.py +89 -10
  47. doctr/models/modules/layers/tensorflow.py +88 -10
  48. doctr/models/modules/transformer/pytorch.py +2 -2
  49. doctr/models/predictor/base.py +77 -50
  50. doctr/models/predictor/pytorch.py +31 -20
  51. doctr/models/predictor/tensorflow.py +27 -17
  52. doctr/models/preprocessor/pytorch.py +4 -4
  53. doctr/models/preprocessor/tensorflow.py +3 -2
  54. doctr/models/recognition/master/pytorch.py +2 -2
  55. doctr/models/recognition/parseq/pytorch.py +4 -3
  56. doctr/models/recognition/parseq/tensorflow.py +4 -3
  57. doctr/models/recognition/sar/pytorch.py +7 -6
  58. doctr/models/recognition/sar/tensorflow.py +3 -9
  59. doctr/models/recognition/vitstr/pytorch.py +1 -1
  60. doctr/models/recognition/zoo.py +1 -1
  61. doctr/models/zoo.py +2 -2
  62. doctr/py.typed +0 -0
  63. doctr/transforms/functional/base.py +1 -1
  64. doctr/transforms/functional/pytorch.py +4 -4
  65. doctr/transforms/modules/base.py +37 -15
  66. doctr/transforms/modules/pytorch.py +66 -8
  67. doctr/transforms/modules/tensorflow.py +63 -7
  68. doctr/utils/fonts.py +7 -5
  69. doctr/utils/geometry.py +35 -12
  70. doctr/utils/metrics.py +33 -174
  71. doctr/utils/reconstitution.py +126 -0
  72. doctr/utils/visualization.py +5 -118
  73. doctr/version.py +1 -1
  74. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/METADATA +96 -91
  75. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/RECORD +79 -75
  76. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/WHEEL +1 -1
  77. doctr/models/artefacts/__init__.py +0 -2
  78. doctr/models/artefacts/barcode.py +0 -74
  79. doctr/models/artefacts/face.py +0 -63
  80. doctr/models/obj_detection/__init__.py +0 -1
  81. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  82. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  83. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/LICENSE +0 -0
  84. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/top_level.txt +0 -0
  85. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/zip-safe +0 -0
doctr/__init__.py CHANGED
@@ -1,3 +1,3 @@
- from . import io, models, datasets, transforms, utils
+ from . import io, models, datasets, contrib, transforms, utils
  from .file_utils import is_tf_available, is_torch_available
  from .version import __version__ # noqa: F401
doctr/contrib/__init__.py ADDED
File without changes
doctr/contrib/artefacts.py ADDED
@@ -0,0 +1,131 @@
+ # Copyright (C) 2021-2024, Mindee.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import cv2
+ import numpy as np
+
+ from doctr.file_utils import requires_package
+
+ from .base import _BasePredictor
+
+ __all__ = ["ArtefactDetector"]
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+ "yolov8_artefact": {
+ "input_shape": (3, 1024, 1024),
+ "labels": ["bar_code", "qr_code", "logo", "photo"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.8.1/yolo_artefact-f9d66f14.onnx&src=0",
+ },
+ }
+
+
+ class ArtefactDetector(_BasePredictor):
+ """
+ A class to detect artefacts in images
+
+ >>> from doctr.io import DocumentFile
+ >>> from doctr.contrib.artefacts import ArtefactDetector
+ >>> doc = DocumentFile.from_images(["path/to/image.jpg"])
+ >>> detector = ArtefactDetector()
+ >>> results = detector(doc)
+
+ Args:
+ ----
+ arch: the architecture to use
+ batch_size: the batch size to use
+ model_path: the path to the model to use
+ labels: the labels to use
+ input_shape: the input shape to use
+ mask_labels: the mask labels to use
+ conf_threshold: the confidence threshold to use
+ iou_threshold: the intersection over union threshold to use
+ **kwargs: additional arguments to be passed to `download_from_url`
+ """
+
+ def __init__(
+ self,
+ arch: str = "yolov8_artefact",
+ batch_size: int = 2,
+ model_path: Optional[str] = None,
+ labels: Optional[List[str]] = None,
+ input_shape: Optional[Tuple[int, int, int]] = None,
+ conf_threshold: float = 0.5,
+ iou_threshold: float = 0.5,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(batch_size=batch_size, url=default_cfgs[arch]["url"], model_path=model_path, **kwargs)
+ self.labels = labels or default_cfgs[arch]["labels"]
+ self.input_shape = input_shape or default_cfgs[arch]["input_shape"]
+ self.conf_threshold = conf_threshold
+ self.iou_threshold = iou_threshold
+
+ def preprocess(self, img: np.ndarray) -> np.ndarray:
+ return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)
+
+ def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
+ results = []
+
+ for batch in zip(output, input_images):
+ for out, img in zip(batch[0], batch[1]):
+ org_height, org_width = img.shape[:2]
+ width_scale, height_scale = org_width / self.input_shape[2], org_height / self.input_shape[1]
+ for res in out:
+ sample_results = []
+ for row in np.transpose(np.squeeze(res)):
+ classes_scores = row[4:]
+ max_score = np.amax(classes_scores)
+ if max_score >= self.conf_threshold:
+ class_id = np.argmax(classes_scores)
+ x, y, w, h = row[0], row[1], row[2], row[3]
+ # to rescaled xmin, ymin, xmax, ymax
+ xmin = int((x - w / 2) * width_scale)
+ ymin = int((y - h / 2) * height_scale)
+ xmax = int((x + w / 2) * width_scale)
+ ymax = int((y + h / 2) * height_scale)
+
+ sample_results.append({
+ "label": self.labels[class_id],
+ "confidence": float(max_score),
+ "box": [xmin, ymin, xmax, ymax],
+ })
+
+ # Filter out overlapping boxes
+ boxes = [res["box"] for res in sample_results]
+ scores = [res["confidence"] for res in sample_results]
+ keep_indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf_threshold, self.iou_threshold) # type: ignore[arg-type]
+ sample_results = [sample_results[i] for i in keep_indices]
+
+ results.append(sample_results)
+
+ self._results = results
+ return results
+
+ def show(self, **kwargs: Any) -> None:
+ """
+ Display the results
+
+ Args:
+ ----
+ **kwargs: additional keyword arguments to be passed to `plt.show`
+ """
+ requires_package("matplotlib", "`.show()` requires matplotlib installed")
+ import matplotlib.pyplot as plt
+ from matplotlib.patches import Rectangle
+
+ # visualize the results with matplotlib
+ if self._results and self._inputs:
+ for img, res in zip(self._inputs, self._results):
+ plt.figure(figsize=(10, 10))
+ plt.imshow(img)
+ for obj in res:
+ xmin, ymin, xmax, ymax = obj["box"]
+ label = obj["label"]
+ plt.text(xmin, ymin, f"{label} {obj['confidence']:.2f}", color="red")
+ plt.gca().add_patch(
+ Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="red", linewidth=2)
+ )
+ plt.show(**kwargs)
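For context, a minimal usage sketch of the new ArtefactDetector added above (it requires `onnxruntime`; the default `yolov8_artefact` ONNX model is downloaded from the URL in `default_cfgs` on first use, and the image path below is a placeholder):

from doctr.io import DocumentFile
from doctr.contrib.artefacts import ArtefactDetector

doc = DocumentFile.from_images(["path/to/image.jpg"])  # placeholder path
detector = ArtefactDetector()  # arch="yolov8_artefact", batch_size=2 by default
results = detector(doc)
# one list per page; each detection is {"label": ..., "confidence": ..., "box": [xmin, ymin, xmax, ymax]}
detector.show()  # optional matplotlib visualization of the last results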
doctr/contrib/base.py ADDED
@@ -0,0 +1,105 @@
+ # Copyright (C) 2021-2024, Mindee.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Any, List, Optional
+
+ import numpy as np
+
+ from doctr.file_utils import requires_package
+ from doctr.utils.data import download_from_url
+
+
+ class _BasePredictor:
+ """
+ Base class for all predictors
+
+ Args:
+ ----
+ batch_size: the batch size to use
+ url: the url to use to download a model if needed
+ model_path: the path to the model to use
+ **kwargs: additional arguments to be passed to `download_from_url`
+ """
+
+ def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None:
+ self.batch_size = batch_size
+ self.session = self._init_model(url, model_path, **kwargs)
+
+ self._inputs: List[np.ndarray] = []
+ self._results: List[Any] = []
+
+ def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
+ """
+ Download the model from the given url if needed
+
+ Args:
+ ----
+ url: the url to use
+ model_path: the path to the model to use
+ **kwargs: additional arguments to be passed to `download_from_url`
+
+ Returns:
+ -------
+ Any: the ONNX loaded model
+ """
+ requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
+ import onnxruntime as ort
+
+ if not url and not model_path:
+ raise ValueError("You must provide either a url or a model_path")
+ onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs)) # type: ignore[arg-type]
+ return ort.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+
+ def preprocess(self, img: np.ndarray) -> np.ndarray:
+ """
+ Preprocess the input image
+
+ Args:
+ ----
+ img: the input image to preprocess
+
+ Returns:
+ -------
+ np.ndarray: the preprocessed image
+ """
+ raise NotImplementedError
+
+ def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
+ """
+ Postprocess the model output
+
+ Args:
+ ----
+ output: the model output to postprocess
+ input_images: the input images used to generate the output
+
+ Returns:
+ -------
+ Any: the postprocessed output
+ """
+ raise NotImplementedError
+
+ def __call__(self, inputs: List[np.ndarray]) -> Any:
+ """
+ Call the model on the given inputs
+
+ Args:
+ ----
+ inputs: the inputs to use
+
+ Returns:
+ -------
+ Any: the postprocessed output
+ """
+ self._inputs = inputs
+ model_inputs = self.session.get_inputs()
+
+ batched_inputs = [inputs[i : i + self.batch_size] for i in range(0, len(inputs), self.batch_size)]
+ processed_batches = [
+ np.array([self.preprocess(img) for img in batch], dtype=np.float32) for batch in batched_inputs
+ ]
+
+ outputs = [self.session.run(None, {model_inputs[0].name: batch}) for batch in processed_batches]
+ return self.postprocess(outputs, batched_inputs)
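A hypothetical sketch of how this base class is meant to be subclassed: only `preprocess` and `postprocess` need implementing, while the ONNX session, model download and batching come from `_BasePredictor`. The class name, model path and output shape below are illustrative assumptions, not part of the package:

import numpy as np
from doctr.contrib.base import _BasePredictor

class TinyOnnxClassifier(_BasePredictor):
    def preprocess(self, img: np.ndarray) -> np.ndarray:
        # HWC uint8 -> CHW float32 in [0, 1], as the assumed ONNX model expects
        return np.transpose(img, (2, 0, 1)).astype(np.float32) / 255.0

    def postprocess(self, output, input_images):
        # assumes a single model output of shape (batch, num_classes) per batch
        return [out.argmax(axis=1).tolist() for batch in output for out in batch]

predictor = TinyOnnxClassifier(batch_size=4, model_path="path/to/classifier.onnx")
# predictions = predictor(list_of_numpy_images)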
doctr/datasets/datasets/pytorch.py CHANGED
@@ -50,9 +50,9 @@ class AbstractDataset(_AbstractDataset):
  @staticmethod
  def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
  images, targets = zip(*samples)
- images = torch.stack(images, dim=0)
+ images = torch.stack(images, dim=0) # type: ignore[assignment]
 
- return images, list(targets)
+ return images, list(targets) # type: ignore[return-value]
 
 
  class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101
doctr/datasets/generator/base.py CHANGED
@@ -20,7 +20,7 @@ def synthesize_text_img(
  font_family: Optional[str] = None,
  background_color: Optional[Tuple[int, int, int]] = None,
  text_color: Optional[Tuple[int, int, int]] = None,
- ) -> Image:
+ ) -> Image.Image:
  """Generate a synthetic text image
 
  Args:
@@ -81,7 +81,7 @@ class _CharacterGenerator(AbstractDataset):
  self._data: List[Image.Image] = []
  if cache_samples:
  self._data = [
- (synthesize_text_img(char, font_family=font), idx)
+ (synthesize_text_img(char, font_family=font), idx) # type: ignore[misc]
  for idx, char in enumerate(self.vocab)
  for font in self.font_family
  ]
@@ -93,7 +93,7 @@ class _CharacterGenerator(AbstractDataset):
  # Samples are already cached
  if len(self._data) > 0:
  idx = index % len(self._data)
- pil_img, target = self._data[idx]
+ pil_img, target = self._data[idx] # type: ignore[misc]
  else:
  target = index % len(self.vocab)
  pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family))
@@ -132,7 +132,8 @@ class _WordGenerator(AbstractDataset):
  if cache_samples:
  _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
  self._data = [
- (synthesize_text_img(text, font_family=random.choice(self.font_family)), text) for text in _words
+ (synthesize_text_img(text, font_family=random.choice(self.font_family)), text) # type: ignore[misc]
+ for text in _words
  ]
 
  def _generate_string(self, min_chars: int, max_chars: int) -> str:
@@ -145,7 +146,7 @@ class _WordGenerator(AbstractDataset):
  def _read_sample(self, index: int) -> Tuple[Any, str]:
  # Samples are already cached
  if len(self._data) > 0:
- pil_img, target = self._data[index]
+ pil_img, target = self._data[index] # type: ignore[misc]
  else:
  target = self._generate_string(*self.wordlen_range)
  pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family))
doctr/datasets/imgur5k.py CHANGED
@@ -112,7 +112,7 @@ class IMGUR5K(AbstractDataset):
  if ann["word"] != "."
  ]
  # (x, y) coordinates of top left, top right, bottom right, bottom left corners
- box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes] # type: ignore[arg-type]
+ box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes]
 
  if not use_polygons:
  # xmin, ymin, xmax, ymax
doctr/datasets/loader.py CHANGED
@@ -9,8 +9,6 @@ from typing import Callable, Optional
  import numpy as np
  import tensorflow as tf
 
- from doctr.utils.multithreading import multithread_exec
-
  __all__ = ["DataLoader"]
 
 
@@ -47,7 +45,6 @@ class DataLoader:
  shuffle: whether the samples should be shuffled before passing it to the iterator
  batch_size: number of elements in each batch
  drop_last: if `True`, drops the last batch if it isn't full
- num_workers: number of workers to use for data loading
  collate_fn: function to merge samples into a batch
  """
 
@@ -57,7 +54,6 @@ class DataLoader:
  shuffle: bool = True,
  batch_size: int = 1,
  drop_last: bool = False,
- num_workers: Optional[int] = None,
  collate_fn: Optional[Callable] = None,
  ) -> None:
  self.dataset = dataset
@@ -69,7 +65,6 @@ class DataLoader:
  self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate
  else:
  self.collate_fn = collate_fn
- self.num_workers = num_workers
  self.reset()
 
  def __len__(self) -> int:
@@ -92,7 +87,7 @@ class DataLoader:
  idx = self._num_yielded * self.batch_size
  indices = self.indices[idx : min(len(self.dataset), idx + self.batch_size)]
 
- samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers))
+ samples = list(map(self.dataset.__getitem__, indices))
 
  batch_data = self.collate_fn(samples)
 
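With `num_workers` and `multithread_exec` removed, batches are now assembled with a plain `map` over the dataset; existing call sites only need to drop the `num_workers` argument. A short sketch (the dataset variable is a placeholder):

from doctr.datasets.loader import DataLoader

# 0.8.0: DataLoader(train_set, batch_size=16, num_workers=4)
# 0.9.0: the num_workers argument no longer exists
train_loader = DataLoader(train_set, shuffle=True, batch_size=16, drop_last=True)
for images, targets in train_loader:
    pass  # training / evaluation step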
doctr/datasets/utils.py CHANGED
@@ -186,7 +186,8 @@ def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> Lis
  -------
  a list of cropped images
  """
- img: np.ndarray = np.array(Image.open(img_path).convert("RGB"))
+ with Image.open(img_path) as pil_img:
+ img: np.ndarray = np.array(pil_img.convert("RGB"))
  # Polygon
  if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
  return extract_rcrops(img, geoms.astype(dtype=int))
doctr/datasets/vocabs.py CHANGED
@@ -17,9 +17,14 @@ VOCABS: Dict[str, str] = {
  "ancient_greek": "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ",
  "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي",
  "persian_letters": "پچڢڤگ",
- "hindi_digits": "٠١٢٣٤٥٦٧٨٩",
+ "arabic_digits": "٠١٢٣٤٥٦٧٨٩",
  "arabic_diacritics": "ًٌٍَُِّْ",
  "arabic_punctuation": "؟؛«»—",
+ "hindi_letters": "अआइईउऊऋॠऌॡएऐओऔअंअःकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह",
+ "hindi_digits": "०१२३४५६७८९",
+ "hindi_punctuation": "।,?!:्ॐ॰॥॰",
+ "bangla_letters": "অআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃেৈোৌ্ৎংঃঁ",
+ "bangla_digits": "০১২৩৪৫৬৭৮৯",
  }
 
  VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"]
@@ -32,7 +37,7 @@ VOCABS["italian"] = VOCABS["english"] + "àèéìíîòóùúÀÈÉÌÍÎÒÓÙ
  VOCABS["german"] = VOCABS["english"] + "äöüßÄÖÜẞ"
  VOCABS["arabic"] = (
  VOCABS["digits"]
- + VOCABS["hindi_digits"]
+ + VOCABS["arabic_digits"]
  + VOCABS["arabic_letters"]
  + VOCABS["persian_letters"]
  + VOCABS["arabic_diacritics"]
@@ -52,6 +57,8 @@ VOCABS["vietnamese"] = (
  + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ"
  )
  VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪"
+ VOCABS["hindi"] = VOCABS["hindi_letters"] + VOCABS["hindi_digits"] + VOCABS["hindi_punctuation"]
+ VOCABS["bangla"] = VOCABS["bangla_letters"] + VOCABS["bangla_digits"]
  VOCABS["multilingual"] = "".join(
  dict.fromkeys(
  VOCABS["french"]
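In short: the key that previously held Arabic-Indic digits is renamed to `arabic_digits`, `hindi_digits` now holds actual Devanagari digits, and full `hindi`/`bangla` vocabularies are exposed. A small sketch of picking one up (the fine-tuning hint is hypothetical):

from doctr.datasets import VOCABS

arabic_digits = VOCABS["arabic_digits"]  # was VOCABS["hindi_digits"] in 0.8.0
hindi_vocab = VOCABS["hindi"]            # new: Devanagari letters + digits + punctuation
bangla_vocab = VOCABS["bangla"]          # new: Bangla letters + digits
# e.g. crnn_vgg16_bn(pretrained=False, vocab=hindi_vocab)  # hypothetical fine-tuning setup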
doctr/file_utils.py CHANGED
@@ -5,21 +5,16 @@
 
  # Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
 
+ import importlib.metadata
  import importlib.util
  import logging
  import os
- import sys
+ from typing import Optional
 
  CLASS_NAME: str = "words"
 
 
- if sys.version_info < (3, 8): # pragma: no cover
- import importlib_metadata
- else:
- import importlib.metadata as importlib_metadata
-
-
- __all__ = ["is_tf_available", "is_torch_available", "CLASS_NAME"]
+ __all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"]
 
  ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
  ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
@@ -32,9 +27,9 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA
  _torch_available = importlib.util.find_spec("torch") is not None
  if _torch_available:
  try:
- _torch_version = importlib_metadata.version("torch")
+ _torch_version = importlib.metadata.version("torch")
  logging.info(f"PyTorch version {_torch_version} available.")
- except importlib_metadata.PackageNotFoundError: # pragma: no cover
+ except importlib.metadata.PackageNotFoundError: # pragma: no cover
  _torch_available = False
  else: # pragma: no cover
  logging.info("Disabling PyTorch because USE_TF is set")
@@ -59,9 +54,9 @@ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VA
  # For the metadata, we have to look for both tensorflow and tensorflow-cpu
  for pkg in candidates:
  try:
- _tf_version = importlib_metadata.version(pkg)
+ _tf_version = importlib.metadata.version(pkg)
  break
- except importlib_metadata.PackageNotFoundError:
+ except importlib.metadata.PackageNotFoundError:
  pass
  _tf_available = _tf_version is not None
  if _tf_available:
@@ -82,6 +77,25 @@ if not _torch_available and not _tf_available: # pragma: no cover
  )
 
 
+ def requires_package(name: str, extra_message: Optional[str] = None) -> None: # pragma: no cover
+ """
+ package requirement helper
+
+ Args:
+ ----
+ name: name of the package
+ extra_message: additional message to display if the package is not found
+ """
+ try:
+ _pkg_version = importlib.metadata.version(name)
+ logging.info(f"{name} version {_pkg_version} available.")
+ except importlib.metadata.PackageNotFoundError:
+ raise ImportError(
+ f"\n\n{extra_message if extra_message is not None else ''} "
+ f"\nPlease install it with the following command: pip install {name}\n"
+ )
+
+
  def is_torch_available():
  """Whether PyTorch is installed."""
  return _torch_available
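`requires_package` replaces the old `importlib_metadata` shim and lets optional dependencies fail lazily with an install hint; it is used by the new `.contrib` module and the `.show()` methods. A minimal usage sketch (the wrapper function is hypothetical):

from doctr.file_utils import requires_package

def plot_results():
    # raises ImportError with a "pip install matplotlib" hint if matplotlib is missing
    requires_package("matplotlib", "`plot_results()` requires matplotlib installed")
    import matplotlib.pyplot as plt
    ...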
doctr/io/elements.py CHANGED
@@ -12,14 +12,19 @@ from xml.etree import ElementTree as ET
  from xml.etree.ElementTree import Element as ETElement
  from xml.etree.ElementTree import SubElement
 
- import matplotlib.pyplot as plt
  import numpy as np
 
  import doctr
+ from doctr.file_utils import requires_package
  from doctr.utils.common_types import BoundingBox
  from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox
+ from doctr.utils.reconstitution import synthesize_kie_page, synthesize_page
  from doctr.utils.repr import NestedObject
- from doctr.utils.visualization import synthesize_kie_page, synthesize_page, visualize_kie_page, visualize_page
+
+ try: # optional dependency for visualization
+ from doctr.utils.visualization import visualize_kie_page, visualize_page
+ except ModuleNotFoundError:
+ pass
 
  __all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"]
 
@@ -67,16 +72,27 @@ class Word(Element):
  confidence: the confidence associated with the text prediction
  geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
  the page's size
+ objectness_score: the objectness score of the detection
+ crop_orientation: the general orientation of the crop in degrees and its confidence
  """
 
- _exported_keys: List[str] = ["value", "confidence", "geometry"]
+ _exported_keys: List[str] = ["value", "confidence", "geometry", "objectness_score", "crop_orientation"]
  _children_names: List[str] = []
 
- def __init__(self, value: str, confidence: float, geometry: Union[BoundingBox, np.ndarray]) -> None:
+ def __init__(
+ self,
+ value: str,
+ confidence: float,
+ geometry: Union[BoundingBox, np.ndarray],
+ objectness_score: float,
+ crop_orientation: Dict[str, Any],
+ ) -> None:
  super().__init__()
  self.value = value
  self.confidence = confidence
  self.geometry = geometry
+ self.objectness_score = objectness_score
+ self.crop_orientation = crop_orientation
 
  def render(self) -> str:
  """Renders the full text of the element"""
@@ -135,7 +151,7 @@ class Line(Element):
  all words in it.
  """
 
- _exported_keys: List[str] = ["geometry"]
+ _exported_keys: List[str] = ["geometry", "objectness_score"]
  _children_names: List[str] = ["words"]
  words: List[Word] = []
 
@@ -143,7 +159,11 @@ class Line(Element):
  self,
  words: List[Word],
  geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
+ objectness_score: Optional[float] = None,
  ) -> None:
+ # Compute the objectness score of the line
+ if objectness_score is None:
+ objectness_score = float(np.mean([w.objectness_score for w in words]))
  # Resolve the geometry using the smallest enclosing bounding box
  if geometry is None:
  # Check whether this is a rotated or straight box
@@ -152,6 +172,7 @@ class Line(Element):
 
  super().__init__(words=words)
  self.geometry = geometry
+ self.objectness_score = objectness_score
 
  def render(self) -> str:
  """Renders the full text of the element"""
@@ -189,7 +210,7 @@ class Block(Element):
  all lines and artefacts in it.
  """
 
- _exported_keys: List[str] = ["geometry"]
+ _exported_keys: List[str] = ["geometry", "objectness_score"]
  _children_names: List[str] = ["lines", "artefacts"]
  lines: List[Line] = []
  artefacts: List[Artefact] = []
@@ -199,7 +220,11 @@ class Block(Element):
  lines: List[Line] = [],
  artefacts: List[Artefact] = [],
  geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
+ objectness_score: Optional[float] = None,
  ) -> None:
+ # Compute the objectness score of the line
+ if objectness_score is None:
+ objectness_score = float(np.mean([w.objectness_score for line in lines for w in line.words]))
  # Resolve the geometry using the smallest enclosing bounding box
  if geometry is None:
  line_boxes = [word.geometry for line in lines for word in line.words]
@@ -211,6 +236,7 @@ class Block(Element):
 
  super().__init__(lines=lines, artefacts=artefacts)
  self.geometry = geometry
+ self.objectness_score = objectness_score
 
  def render(self, line_break: str = "\n") -> str:
  """Renders the full text of the element"""
@@ -274,6 +300,10 @@ class Page(Element):
  preserve_aspect_ratio: pass True if you passed True to the predictor
  **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method
  """
+ requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed")
+ requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed")
+ import matplotlib.pyplot as plt
+
  visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
  plt.show(**kwargs)
 
@@ -449,6 +479,10 @@ class KIEPage(Element):
  preserve_aspect_ratio: pass True if you passed True to the predictor
  **kwargs: keyword arguments passed to the matplotlib.pyplot.show method
  """
+ requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed")
+ requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed")
+ import matplotlib.pyplot as plt
+
  visualize_kie_page(
  self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio
  )
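`Word` (and, by aggregation, `Line` and `Block`) now exports `objectness_score` and `crop_orientation`, so constructing elements by hand needs the extra arguments; a small sketch with illustrative values (the shape of the crop_orientation dict shown here is an assumption):

from doctr.io.elements import Line, Word

word = Word(
    value="hello",
    confidence=0.99,
    geometry=((0.1, 0.1), (0.3, 0.2)),
    objectness_score=0.95,  # new required argument
    crop_orientation={"value": 0, "confidence": None},  # new required argument (illustrative dict)
)
line = Line(words=[word])  # objectness_score defaults to the mean over its words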
doctr/io/html.py CHANGED
@@ -5,8 +5,6 @@
 
  from typing import Any
 
- from weasyprint import HTML
-
  __all__ = ["read_html"]
 
 
@@ -25,4 +23,6 @@ def read_html(url: str, **kwargs: Any) -> bytes:
  -------
  decoded PDF file as a bytes stream
  """
+ from weasyprint import HTML
+
  return HTML(url, **kwargs).write_pdf()
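`weasyprint` is now imported lazily inside `read_html`, so importing `doctr.io` no longer requires it; the function itself behaves as before (the URL below is a placeholder):

from doctr.io import DocumentFile, read_html

pdf_bytes = read_html("https://www.example.com")  # needs weasyprint only at call time
doc = DocumentFile.from_pdf(pdf_bytes)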
doctr/io/image/pytorch.py CHANGED
@@ -16,7 +16,7 @@ from doctr.utils.common_types import AbstractPath
  __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
 
 
- def tensor_from_pil(pil_img: Image, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> torch.Tensor:
  """Convert a PIL Image to a PyTorch tensor
 
  Args:
@@ -51,9 +51,8 @@ def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float3
  if dtype not in (torch.uint8, torch.float16, torch.float32):
  raise ValueError("insupported value for dtype")
 
- pil_img = Image.open(img_path, mode="r").convert("RGB")
-
- return tensor_from_pil(pil_img, dtype)
+ with Image.open(img_path, mode="r") as pil_img:
+ return tensor_from_pil(pil_img.convert("RGB"), dtype)
 
 
  def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor:
@@ -71,9 +70,8 @@ def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32)
  if dtype not in (torch.uint8, torch.float16, torch.float32):
  raise ValueError("insupported value for dtype")
 
- pil_img = Image.open(BytesIO(img_content), mode="r").convert("RGB")
-
- return tensor_from_pil(pil_img, dtype)
+ with Image.open(BytesIO(img_content), mode="r") as pil_img:
+ return tensor_from_pil(pil_img.convert("RGB"), dtype)
 
 
  def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor:
@@ -106,4 +104,4 @@ def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -
 
  def get_img_shape(img: torch.Tensor) -> Tuple[int, int]:
  """Get the shape of an image"""
- return img.shape[-2:]
+ return img.shape[-2:] # type: ignore[return-value]
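The image-loading helpers now close file handles via context managers; the public API is unchanged. A quick sketch (paths are placeholders):

import torch
from doctr.io.image.pytorch import decode_img_as_tensor, read_img_as_tensor

img = read_img_as_tensor("path/to/image.jpg", dtype=torch.float32)  # CHW float tensor
with open("path/to/image.jpg", "rb") as f:
    same_img = decode_img_as_tensor(f.read(), dtype=torch.float32)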
doctr/io/image/tensorflow.py CHANGED
@@ -15,7 +15,7 @@ from doctr.utils.common_types import AbstractPath
  __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
 
 
- def tensor_from_pil(pil_img: Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
+ def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
  """Convert a PIL Image to a TensorFlow tensor
 
  Args: