python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
doctr/models/modules/transformer/tensorflow.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -42,12 +42,15 @@ class PositionalEncoding(layers.Layer, NestedObject):
         x: tf.Tensor,
         **kwargs: Any,
     ) -> tf.Tensor:
-        """
+        """Forward pass
+
         Args:
+        ----
            x: embeddings (batch, max_len, d_model)
            **kwargs: additional arguments
 
-        Returns:
+        Returns
+        -------
            positional embeddings (batch, max_len, d_model)
        """
        if x.dtype == tf.float16:  # amp fix: cast to half
@@ -62,7 +65,6 @@ def scaled_dot_product_attention(
     query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: Optional[tf.Tensor] = None
 ) -> Tuple[tf.Tensor, tf.Tensor]:
     """Scaled Dot-Product Attention"""
-
     scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) / math.sqrt(query.shape[-1])
     if mask is not None:
         # NOTE: to ensure the ONNX compatibility, tf.where works only with bool type condition
doctr/models/modules/vision_transformer/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -28,8 +28,7 @@ class PatchEmbedding(nn.Module):
         self.projection = nn.Conv2d(channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
 
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
-        """
-        100 % borrowed from:
+        """100 % borrowed from:
         https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py
 
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
@@ -38,7 +37,6 @@ class PatchEmbedding(nn.Module):
         Source:
         https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py
         """
-
         num_patches = embeddings.shape[1] - 1
         num_positions = self.positions.shape[1] - 1
         if num_patches == num_positions and height == width:
doctr/models/modules/vision_transformer/tensorflow.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -45,8 +45,7 @@ class PatchEmbedding(layers.Layer, NestedObject):
         )
 
     def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
-        """
-        100 % borrowed from:
+        """100 % borrowed from:
         https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_tf_vit.py
 
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
@@ -55,7 +54,6 @@ class PatchEmbedding(layers.Layer, NestedObject):
         Source:
         https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py
         """
-
         seq_len, dim = embeddings.shape[1:]
         num_patches = seq_len - 1
 
doctr/models/obj_detection/faster_rcnn/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -70,10 +70,12 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained: bool = False, **kwargs: Any) -
     >>> out = model(input_tensor)
 
     Args:
+    ----
        pretrained (bool): If True, returns a model pre-trained on our object detection dataset
+       **kwargs: keyword arguments of the FasterRCNN architecture
 
     Returns:
+    -------
        object detection architecture
    """
-
    return _fasterrcnn("fasterrcnn_mobilenet_v3_large_fpn", pretrained, **kwargs)
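
For context, a minimal usage sketch of the factory documented in the hunk above. The import path follows the file layout in this diff; the input shape, eval-mode call, and output description are illustrative assumptions, not part of the diff:

    # Hypothetical usage sketch (PyTorch backend assumed)
    import torch
    from doctr.models import obj_detection

    model = obj_detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False).eval()
    input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)  # illustrative shape
    with torch.no_grad():
        out = model(input_tensor)  # torchvision-style list of dicts (boxes, labels, scores)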
doctr/models/predictor/base.py

@@ -1,9 +1,9 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
 
 import numpy as np
 
@@ -21,6 +21,7 @@ class _OCRPredictor:
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
+    ----
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
            without rotated textual elements.
        straighten_pages: if True, estimates the page general orientation based on the median line orientation.
@@ -28,7 +29,7 @@ class _OCRPredictor:
            accordingly. Doing so will improve performances for documents with page-uniform rotations.
        preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
        symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically.
-       kwargs: keyword args of `DocumentBuilder`
+       **kwargs: keyword args of `DocumentBuilder`
    """
 
    crop_orientation_predictor: Optional[CropOrientationPredictor]
@@ -47,6 +48,7 @@ class _OCRPredictor:
        self.doc_builder = DocumentBuilder(**kwargs)
        self.preserve_aspect_ratio = preserve_aspect_ratio
        self.symmetric_pad = symmetric_pad
+       self.hooks: List[Callable] = []
 
    @staticmethod
    def _generate_crops(
@@ -148,3 +150,12 @@ class _OCRPredictor:
            _idx += page_boxes.shape[0]
 
        return loc_preds, text_preds
+
+   def add_hook(self, hook: Callable) -> None:
+       """Add a hook to the predictor
+
+       Args:
+       ----
+           hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
+       """
+       self.hooks.append(hook)
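
The `add_hook` method above is new in 0.8.1: a hook receives the per-page localization arrays (`loc_preds`) and must return them, possibly modified. A minimal sketch using the public `ocr_predictor` factory; the hook body itself is purely illustrative:

    import numpy as np
    from doctr.models import ocr_predictor

    def clamp_boxes_hook(loc_preds):
        # Illustrative hook: clamp relative box coordinates into [0, 1]
        return [np.clip(page_boxes, 0.0, 1.0) for page_boxes in loc_preds]

    predictor = ocr_predictor(pretrained=True)
    predictor.add_hook(clamp_boxes_hook)  # applied to loc_preds on every call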
doctr/models/predictor/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,7 +13,7 @@ from doctr.io.elements import Document
 from doctr.models._utils import estimate_orientation, get_language
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import rotate_boxes, rotate_image
+from doctr.utils.geometry import rotate_image
 
 from .base import _OCRPredictor
 
@@ -24,6 +24,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
    """Implements an object able to localize and identify text elements in a set of documents
 
    Args:
+   ----
        det_predictor: detection module
        reco_predictor: recognition module
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -35,7 +36,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
            page. Doing so will slightly deteriorate the overall latency.
        detect_language: if True, the language prediction will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
-       kwargs: keyword args of `DocumentBuilder`
+       **kwargs: keyword args of `DocumentBuilder`
    """
 
    def __init__(
@@ -59,7 +60,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
        self.detect_orientation = detect_orientation
        self.detect_language = detect_language
 
-   @torch.no_grad()
+   @torch.inference_mode()
    def forward(
        self,
        pages: List[Union[np.ndarray, torch.Tensor]],
@@ -71,11 +72,18 @@
 
        origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
 
+       # Localize text elements
+       loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
       # Detect document rotation and rotate pages
+       seg_maps = [
+           np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8)
+           for out_map in out_maps
+       ]
       if self.detect_orientation:
-           origin_page_orientations = [estimate_orientation(page) for page in pages]  # type: ignore[arg-type]
+           origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
           orientations = [
-               {"value": orientation_page, "confidence": 1.0} for orientation_page in origin_page_orientations
+               {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
           ]
       else:
           orientations = None
@@ -83,15 +91,12 @@
           origin_page_orientations = (
               origin_page_orientations
               if self.detect_orientation
-               else [estimate_orientation(page) for page in pages]  # type: ignore[arg-type]
+               else [estimate_orientation(seq_map) for seq_map in seg_maps]
           )
-           pages = [
-               rotate_image(page, -angle, expand=True)  # type: ignore[arg-type]
-               for page, angle in zip(pages, origin_page_orientations)
-           ]
+           pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+           # Forward again to get predictions on straight pages
+           loc_preds = self.det_predictor(pages, **kwargs)
 
-       # Localize text elements
-       loc_preds = self.det_predictor(pages, **kwargs)
       assert all(
           len(loc_pred) == 1 for loc_pred in loc_preds
       ), "Detection Model in ocr_predictor should output only one class"
@@ -101,11 +106,15 @@
       channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
 
       # Rectify crops if aspect ratio
-       loc_preds = self._remove_padding(pages, loc_preds)  # type: ignore[arg-type]
+       loc_preds = self._remove_padding(pages, loc_preds)
+
+       # Apply hooks to loc_preds if any
+       for hook in self.hooks:
+           loc_preds = hook(loc_preds)
 
       # Crop images
       crops, loc_preds = self._prepare_crops(
-           pages,  # type: ignore[arg-type]
+           pages,
           loc_preds,
           channels_last=channels_last,
           assume_straight_pages=self.assume_straight_pages,
@@ -123,24 +132,12 @@
           languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
       else:
           languages_dict = None
-       # Rotate back pages and boxes while keeping original image size
-       if self.straighten_pages:
-           boxes = [
-               rotate_boxes(
-                   page_boxes,
-                   angle,
-                   orig_shape=page.shape[:2]
-                   if isinstance(page, np.ndarray)
-                   else page.shape[1:],  # type: ignore[arg-type]
-                   target_shape=mask,  # type: ignore[arg-type]
-               )
-               for page_boxes, page, angle, mask in zip(boxes, pages, origin_page_orientations, origin_page_shapes)
-           ]
 
       out = self.doc_builder(
+           pages,
           boxes,
           text_preds,
-           [page.shape[:2] if channels_last else page.shape[-2:] for page in pages],  # type: ignore[misc]
+           origin_page_shapes,
           orientations,
           languages_dict,
       )
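
The orientation change above boils down to binarizing the detector's output map at the postprocessor's `bin_thresh` and estimating the page angle from that mask rather than from the raw page image. Restated standalone (the function name is ours; the logic mirrors the diff):

    import numpy as np

    def binarize_out_map(out_map: np.ndarray, bin_thresh: float) -> np.ndarray:
        # uint8 mask: 255 where the detector scores above threshold, 0 elsewhere
        return np.where(out_map > bin_thresh, 255, 0).astype(np.uint8)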
doctr/models/predictor/tensorflow.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,7 +12,7 @@ from doctr.io.elements import Document
 from doctr.models._utils import estimate_orientation, get_language
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import rotate_boxes, rotate_image
+from doctr.utils.geometry import rotate_image
 from doctr.utils.repr import NestedObject
 
 from .base import _OCRPredictor
@@ -24,6 +24,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
    """Implements an object able to localize and identify text elements in a set of documents
 
    Args:
+   ----
        det_predictor: detection module
        reco_predictor: recognition module
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -35,7 +36,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
            page. Doing so will slightly deteriorate the overall latency.
        detect_language: if True, the language prediction will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
-       kwargs: keyword args of `DocumentBuilder`
+       **kwargs: keyword args of `DocumentBuilder`
    """
 
    _children_names = ["det_predictor", "reco_predictor", "doc_builder"]
@@ -71,31 +72,43 @@ class OCRPredictor(NestedObject, _OCRPredictor):
 
       origin_page_shapes = [page.shape[:2] for page in pages]
 
+       # Localize text elements
+       loc_preds_dict, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
      # Detect document rotation and rotate pages
+       seg_maps = [
+           np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8)
+           for out_map in out_maps
+       ]
      if self.detect_orientation:
-           origin_page_orientations = [estimate_orientation(page) for page in pages]
+           origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
          orientations = [
-               {"value": orientation_page, "confidence": 1.0} for orientation_page in origin_page_orientations
+               {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
          ]
      else:
          orientations = None
      if self.straighten_pages:
          origin_page_orientations = (
-               origin_page_orientations if self.detect_orientation else [estimate_orientation(page) for page in pages]
+               origin_page_orientations
+               if self.detect_orientation
+               else [estimate_orientation(seq_map) for seq_map in seg_maps]
          )
-           pages = [rotate_image(page, -angle, expand=True) for page, angle in zip(pages, origin_page_orientations)]
+           pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+           # forward again to get predictions on straight pages
+           loc_preds_dict = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
 
-       # Localize text elements
-       loc_preds_dict = self.det_predictor(pages, **kwargs)
      assert all(
          len(loc_pred) == 1 for loc_pred in loc_preds_dict
      ), "Detection Model in ocr_predictor should output only one class"
-
-       loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict]
+       loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict]  # type: ignore[union-attr]
 
      # Rectify crops if aspect ratio
      loc_preds = self._remove_padding(pages, loc_preds)
 
+       # Apply hooks to loc_preds if any
+       for hook in self.hooks:
+           loc_preds = hook(loc_preds)
+
      # Crop images
      crops, loc_preds = self._prepare_crops(
          pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
@@ -114,19 +127,9 @@ class OCRPredictor(NestedObject, _OCRPredictor):
          languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
      else:
          languages_dict = None
-       # Rotate back pages and boxes while keeping original image size
-       if self.straighten_pages:
-           boxes = [
-               rotate_boxes(
-                   page_boxes,
-                   angle,
-                   orig_shape=page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:],
-                   target_shape=mask,  # type: ignore[arg-type]
-               )
-               for page_boxes, page, angle, mask in zip(boxes, pages, origin_page_orientations, origin_page_shapes)
-           ]
 
      out = self.doc_builder(
+           pages,
          boxes,
          text_preds,
          origin_page_shapes,  # type: ignore[arg-type]
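
One visible consequence of the hunks above: with `detect_orientation=True`, each page's orientation entry now reports `confidence: None`, since the angle is derived from the segmentation map rather than from a classifier. A hedged end-to-end sketch (the file path and printed value are placeholders):

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True, detect_orientation=True)
    pages = DocumentFile.from_images(["page.jpg"])  # placeholder path
    result = predictor(pages)
    print(result.pages[0].orientation)  # e.g. {'value': -2, 'confidence': None}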
doctr/models/preprocessor/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -22,6 +22,7 @@ class PreProcessor(nn.Module):
    """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
 
    Args:
+   ----
        output_size: expected size of each page in format (H, W)
        batch_size: the size of page batches
        mean: mean value of the training distribution by channel
@@ -34,7 +35,6 @@ class PreProcessor(nn.Module):
        batch_size: int,
        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
-       fp16: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__()
@@ -47,12 +47,13 @@ class PreProcessor(nn.Module):
        """Gather samples into batches for inference purposes
 
        Args:
+       ----
            samples: list of samples of shape (C, H, W)
 
        Returns:
+       -------
            list of batched samples (*, C, H, W)
        """
-
        num_batches = int(math.ceil(len(samples) / self.batch_size))
        batches = [
            torch.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], dim=0)
@@ -78,17 +79,19 @@ class PreProcessor(nn.Module):
        else:
            x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
 
-       return x
+       return x  # type: ignore[return-value]
 
    def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]:
        """Prepare document data for model forwarding
 
        Args:
+       ----
            x: list of images (np.array) or tensors (already resized and batched)
+
        Returns:
+       -------
            list of page batches
        """
-
        # Input type check
        if isinstance(x, (np.ndarray, torch.Tensor)):
            if x.ndim != 4:
@@ -100,8 +103,10 @@ class PreProcessor(nn.Module):
            elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
                raise TypeError("unsupported data type for torch.Tensor")
            # Resizing
-           if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
-               x = F.resize(x, self.resize.size, interpolation=self.resize.interpolation)
+           if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:  # type: ignore[union-attr]
+               x = F.resize(
+                   x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
+               )
            # Data type
            if x.dtype == torch.uint8:  # type: ignore[union-attr]
                x = x.to(dtype=torch.float32).div(255).clip(0, 1)  # type: ignore[union-attr]
@@ -113,11 +118,11 @@ class PreProcessor(nn.Module):
            # Sample transform (to tensor, resize)
            samples = list(multithread_exec(self.sample_transforms, x))
            # Batching
-           batches = self.batch_inputs(samples)
+           batches = self.batch_inputs(samples)  # type: ignore[assignment]
        else:
            raise TypeError(f"invalid input type: {type(x)}")
 
        # Batch transforms (normalize)
        batches = list(multithread_exec(self.normalize, batches))
 
-       return batches
+       return batches  # type: ignore[return-value]
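
For reference, the batching step that `batch_inputs` keeps using (unchanged context above) is a plain ceil-division split. Restated standalone under assumed types; the final batch may be smaller than `batch_size`:

    import math
    from typing import List

    import torch

    def batch_inputs(samples: List[torch.Tensor], batch_size: int) -> List[torch.Tensor]:
        # ceil(len / batch_size) batches, each built with torch.stack
        num_batches = int(math.ceil(len(samples) / batch_size))
        return [
            torch.stack(samples[idx * batch_size : min((idx + 1) * batch_size, len(samples))], dim=0)
            for idx in range(num_batches)
        ]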
doctr/models/preprocessor/tensorflow.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -20,6 +20,7 @@ class PreProcessor(NestedObject):
    """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
 
    Args:
+   ----
        output_size: expected size of each page in format (H, W)
        batch_size: the size of page batches
        mean: mean value of the training distribution by channel
@@ -34,7 +35,6 @@ class PreProcessor(NestedObject):
        batch_size: int,
        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
-       fp16: bool = False,
        **kwargs: Any,
    ) -> None:
        self.batch_size = batch_size
@@ -46,12 +46,13 @@ class PreProcessor(NestedObject):
        """Gather samples into batches for inference purposes
 
        Args:
+       ----
            samples: list of samples (tf.Tensor)
 
        Returns:
+       -------
            list of batched samples
        """
-
        num_batches = int(math.ceil(len(samples) / self.batch_size))
        batches = [
            tf.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0)
@@ -81,11 +82,13 @@ class PreProcessor(NestedObject):
        """Prepare document data for model forwarding
 
        Args:
+       ----
            x: list of images (np.array) or tensors (already resized and batched)
+
        Returns:
+       -------
            list of page batches
        """
-
        # Input type check
        if isinstance(x, (np.ndarray, tf.Tensor)):
            if x.ndim != 4:
@@ -102,7 +105,9 @@ class PreProcessor(NestedObject):
            x = tf.image.convert_image_dtype(x, dtype=tf.float32)
            # Resizing
            if (x.shape[1], x.shape[2]) != self.resize.output_size:
-               x = tf.image.resize(x, self.resize.output_size, method=self.resize.method)
+               x = tf.image.resize(
+                   x, self.resize.output_size, method=self.resize.method, antialias=self.resize.antialias
+               )
 
            batches = [x]
doctr/models/recognition/core.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -27,9 +27,11 @@ class RecognitionModel(NestedObject):
        sequence lengths.
 
        Args:
+       ----
            gts: list of ground-truth labels
 
        Returns:
+       -------
            A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
        """
        encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab))
@@ -41,6 +43,7 @@ class RecognitionPostProcessor(NestedObject):
    """Abstract class to postprocess the raw output of the model
 
    Args:
+   ----
        vocab: string containing the ordered sequence of supported characters
    """