python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,14 +1,15 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Callable, Dict, List, Optional, Tuple
6
+ from collections.abc import Callable
7
+ from typing import Any
7
8
 
8
9
  import numpy as np
9
10
 
10
11
  from doctr.models.builder import DocumentBuilder
11
- from doctr.utils.geometry import extract_crops, extract_rcrops, rotate_image
12
+ from doctr.utils.geometry import extract_crops, extract_rcrops, remove_image_padding, rotate_image
12
13
 
13
14
  from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
14
15
  from ..classification import crop_orientation_predictor, page_orientation_predictor
@@ -21,7 +22,6 @@ class _OCRPredictor:
21
22
  """Implements an object able to localize and identify text elements in a set of documents
22
23
 
23
24
  Args:
24
- ----
25
25
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
26
26
  without rotated textual elements.
27
27
  straighten_pages: if True, estimates the page general orientation based on the median line orientation.
@@ -34,8 +34,8 @@ class _OCRPredictor:
34
34
  **kwargs: keyword args of `DocumentBuilder`
35
35
  """
36
36
 
37
- crop_orientation_predictor: Optional[OrientationPredictor]
38
- page_orientation_predictor: Optional[OrientationPredictor]
37
+ crop_orientation_predictor: OrientationPredictor | None
38
+ page_orientation_predictor: OrientationPredictor | None
39
39
 
40
40
  def __init__(
41
41
  self,
@@ -48,21 +48,27 @@ class _OCRPredictor:
48
48
  ) -> None:
49
49
  self.assume_straight_pages = assume_straight_pages
50
50
  self.straighten_pages = straighten_pages
51
- self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True)
51
+ self._page_orientation_disabled = kwargs.pop("disable_page_orientation", False)
52
+ self._crop_orientation_disabled = kwargs.pop("disable_crop_orientation", False)
53
+ self.crop_orientation_predictor = (
54
+ None
55
+ if assume_straight_pages
56
+ else crop_orientation_predictor(pretrained=True, disabled=self._crop_orientation_disabled)
57
+ )
52
58
  self.page_orientation_predictor = (
53
- page_orientation_predictor(pretrained=True)
59
+ page_orientation_predictor(pretrained=True, disabled=self._page_orientation_disabled)
54
60
  if detect_orientation or straighten_pages or not assume_straight_pages
55
61
  else None
56
62
  )
57
63
  self.doc_builder = DocumentBuilder(**kwargs)
58
64
  self.preserve_aspect_ratio = preserve_aspect_ratio
59
65
  self.symmetric_pad = symmetric_pad
60
- self.hooks: List[Callable] = []
66
+ self.hooks: list[Callable] = []
61
67
 
62
68
  def _general_page_orientations(
63
69
  self,
64
- pages: List[np.ndarray],
65
- ) -> List[Tuple[int, float]]:
70
+ pages: list[np.ndarray],
71
+ ) -> list[tuple[int, float]]:
66
72
  _, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
67
73
  # Flatten to list of tuples with (value, confidence)
68
74
  page_orientations = [
@@ -73,8 +79,8 @@ class _OCRPredictor:
73
79
  return page_orientations
74
80
 
75
81
  def _get_orientations(
76
- self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
77
- ) -> Tuple[List[Tuple[int, float]], List[int]]:
82
+ self, pages: list[np.ndarray], seg_maps: list[np.ndarray]
83
+ ) -> tuple[list[tuple[int, float]], list[int]]:
78
84
  general_pages_orientations = self._general_page_orientations(pages)
79
85
  origin_page_orientations = [
80
86
  estimate_orientation(seq_map, general_orientation)
@@ -84,11 +90,11 @@ class _OCRPredictor:
84
90
 
85
91
  def _straighten_pages(
86
92
  self,
87
- pages: List[np.ndarray],
88
- seg_maps: List[np.ndarray],
89
- general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
90
- origin_pages_orientations: Optional[List[int]] = None,
91
- ) -> List[np.ndarray]:
93
+ pages: list[np.ndarray],
94
+ seg_maps: list[np.ndarray],
95
+ general_pages_orientations: list[tuple[int, float]] | None = None,
96
+ origin_pages_orientations: list[int] | None = None,
97
+ ) -> list[np.ndarray]:
92
98
  general_pages_orientations = (
93
99
  general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
94
100
  )
@@ -101,34 +107,40 @@ class _OCRPredictor:
101
107
  ]
102
108
  )
103
109
  return [
104
- # We exapnd if the page is wider than tall and the angle is 90 or -90
105
- rotate_image(page, angle, expand=page.shape[1] > page.shape[0] and abs(angle) == 90)
110
+ # expand if height and width are not equal, then remove the padding
111
+ remove_image_padding(rotate_image(page, angle, expand=page.shape[0] != page.shape[1]))
106
112
  for page, angle in zip(pages, origin_pages_orientations)
107
113
  ]
108
114
 
109
115
  @staticmethod
110
116
  def _generate_crops(
111
- pages: List[np.ndarray],
112
- loc_preds: List[np.ndarray],
117
+ pages: list[np.ndarray],
118
+ loc_preds: list[np.ndarray],
113
119
  channels_last: bool,
114
120
  assume_straight_pages: bool = False,
115
- ) -> List[List[np.ndarray]]:
116
- extraction_fn = extract_crops if assume_straight_pages else extract_rcrops
117
-
118
- crops = [
119
- extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator]
120
- for page, _boxes in zip(pages, loc_preds)
121
- ]
121
+ assume_horizontal: bool = False,
122
+ ) -> list[list[np.ndarray]]:
123
+ if assume_straight_pages:
124
+ crops = [
125
+ extract_crops(page, _boxes[:, :4], channels_last=channels_last)
126
+ for page, _boxes in zip(pages, loc_preds)
127
+ ]
128
+ else:
129
+ crops = [
130
+ extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
131
+ for page, _boxes in zip(pages, loc_preds)
132
+ ]
122
133
  return crops
123
134
 
124
135
  @staticmethod
125
136
  def _prepare_crops(
126
- pages: List[np.ndarray],
127
- loc_preds: List[np.ndarray],
137
+ pages: list[np.ndarray],
138
+ loc_preds: list[np.ndarray],
128
139
  channels_last: bool,
129
140
  assume_straight_pages: bool = False,
130
- ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
131
- crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)
141
+ assume_horizontal: bool = False,
142
+ ) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
143
+ crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
132
144
 
133
145
  # Avoid sending zero-sized crops
134
146
  is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
@@ -142,9 +154,9 @@ class _OCRPredictor:
142
154
 
143
155
  def _rectify_crops(
144
156
  self,
145
- crops: List[List[np.ndarray]],
146
- loc_preds: List[np.ndarray],
147
- ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]:
157
+ crops: list[list[np.ndarray]],
158
+ loc_preds: list[np.ndarray],
159
+ ) -> tuple[list[list[np.ndarray]], list[np.ndarray], list[tuple[int, float]]]:
148
160
  # Work at a page level
149
161
  orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc]
150
162
  rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
@@ -162,10 +174,10 @@ class _OCRPredictor:
162
174
 
163
175
  @staticmethod
164
176
  def _process_predictions(
165
- loc_preds: List[np.ndarray],
166
- word_preds: List[Tuple[str, float]],
167
- crop_orientations: List[Dict[str, Any]],
168
- ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]:
177
+ loc_preds: list[np.ndarray],
178
+ word_preds: list[tuple[str, float]],
179
+ crop_orientations: list[dict[str, Any]],
180
+ ) -> tuple[list[np.ndarray], list[list[tuple[str, float]]], list[list[dict[str, Any]]]]:
169
181
  text_preds = []
170
182
  crop_orientation_preds = []
171
183
  if len(loc_preds) > 0:
@@ -182,7 +194,6 @@ class _OCRPredictor:
182
194
  """Add a hook to the predictor
183
195
 
184
196
  Args:
185
- ----
186
197
  hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
187
198
  """
188
199
  self.hooks.append(hook)
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, List, Union
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import torch
@@ -24,7 +24,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
24
24
  """Implements an object able to localize and identify text elements in a set of documents
25
25
 
26
26
  Args:
27
- ----
28
27
  det_predictor: detection module
29
28
  reco_predictor: recognition module
30
29
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -52,8 +51,8 @@ class OCRPredictor(nn.Module, _OCRPredictor):
52
51
  **kwargs: Any,
53
52
  ) -> None:
54
53
  nn.Module.__init__(self)
55
- self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
56
- self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
54
+ self.det_predictor = det_predictor.eval()
55
+ self.reco_predictor = reco_predictor.eval()
57
56
  _OCRPredictor.__init__(
58
57
  self,
59
58
  assume_straight_pages,
@@ -69,7 +68,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
69
68
  @torch.inference_mode()
70
69
  def forward(
71
70
  self,
72
- pages: List[Union[np.ndarray, torch.Tensor]],
71
+ pages: list[np.ndarray | torch.Tensor],
73
72
  **kwargs: Any,
74
73
  ) -> Document:
75
74
  # Dimension check
@@ -87,7 +86,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
87
86
  for out_map in out_maps
88
87
  ]
89
88
  if self.detect_orientation:
90
- general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) # type: ignore[arg-type]
89
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
91
90
  orientations = [
92
91
  {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
93
92
  ]
@@ -96,13 +95,16 @@ class OCRPredictor(nn.Module, _OCRPredictor):
96
95
  general_pages_orientations = None
97
96
  origin_pages_orientations = None
98
97
  if self.straighten_pages:
99
- pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
98
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
99
+ # update page shapes after straightening
100
+ origin_page_shapes = [page.shape[:2] for page in pages]
101
+
100
102
  # Forward again to get predictions on straight pages
101
103
  loc_preds = self.det_predictor(pages, **kwargs)
102
104
 
103
- assert all(
104
- len(loc_pred) == 1 for loc_pred in loc_preds
105
- ), "Detection Model in ocr_predictor should output only one class"
105
+ assert all(len(loc_pred) == 1 for loc_pred in loc_preds), (
106
+ "Detection Model in ocr_predictor should output only one class"
107
+ )
106
108
 
107
109
  loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
108
110
  # Detach objectness scores from loc_preds
@@ -116,10 +118,11 @@ class OCRPredictor(nn.Module, _OCRPredictor):
116
118
 
117
119
  # Crop images
118
120
  crops, loc_preds = self._prepare_crops(
119
- pages, # type: ignore[arg-type]
121
+ pages,
120
122
  loc_preds,
121
123
  channels_last=channels_last,
122
124
  assume_straight_pages=self.assume_straight_pages,
125
+ assume_horizontal=self._page_orientation_disabled,
123
126
  )
124
127
  # Rectify crop orientation and get crop orientation predictions
125
128
  crop_orientations: Any = []
@@ -143,7 +146,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
143
146
  languages_dict = None
144
147
 
145
148
  out = self.doc_builder(
146
- pages, # type: ignore[arg-type]
149
+ pages,
147
150
  boxes,
148
151
  objectness_scores,
149
152
  text_preds,
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, List, Union
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import tensorflow as tf
@@ -24,7 +24,6 @@ class OCRPredictor(NestedObject, _OCRPredictor):
24
24
  """Implements an object able to localize and identify text elements in a set of documents
25
25
 
26
26
  Args:
27
- ----
28
27
  det_predictor: detection module
29
28
  reco_predictor: recognition module
30
29
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -69,7 +68,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
69
68
 
70
69
  def __call__(
71
70
  self,
72
- pages: List[Union[np.ndarray, tf.Tensor]],
71
+ pages: list[np.ndarray | tf.Tensor],
73
72
  **kwargs: Any,
74
73
  ) -> Document:
75
74
  # Dimension check
@@ -97,13 +96,16 @@ class OCRPredictor(NestedObject, _OCRPredictor):
97
96
  origin_pages_orientations = None
98
97
  if self.straighten_pages:
99
98
  pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
99
+ # update page shapes after straightening
100
+ origin_page_shapes = [page.shape[:2] for page in pages]
101
+
100
102
  # forward again to get predictions on straight pages
101
- loc_preds_dict = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
103
+ loc_preds_dict = self.det_predictor(pages, **kwargs)
102
104
 
103
- assert all(
104
- len(loc_pred) == 1 for loc_pred in loc_preds_dict
105
- ), "Detection Model in ocr_predictor should output only one class"
106
- loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # type: ignore[union-attr]
105
+ assert all(len(loc_pred) == 1 for loc_pred in loc_preds_dict), (
106
+ "Detection Model in ocr_predictor should output only one class"
107
+ )
108
+ loc_preds: list[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict]
107
109
  # Detach objectness scores from loc_preds
108
110
  loc_preds, objectness_scores = detach_scores(loc_preds)
109
111
 
@@ -113,7 +115,11 @@ class OCRPredictor(NestedObject, _OCRPredictor):
113
115
 
114
116
  # Crop images
115
117
  crops, loc_preds = self._prepare_crops(
116
- pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
118
+ pages,
119
+ loc_preds,
120
+ channels_last=True,
121
+ assume_straight_pages=self.assume_straight_pages,
122
+ assume_horizontal=self._page_orientation_disabled,
117
123
  )
118
124
  # Rectify crop orientation and get crop orientation predictions
119
125
  crop_orientations: Any = []
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]
@@ -1,10 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  import math
7
- from typing import Any, List, Tuple, Union
7
+ from typing import Any
8
8
 
9
9
  import numpy as np
10
10
  import torch
@@ -22,19 +22,19 @@ class PreProcessor(nn.Module):
22
22
  """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
23
23
 
24
24
  Args:
25
- ----
26
25
  output_size: expected size of each page in format (H, W)
27
26
  batch_size: the size of page batches
28
27
  mean: mean value of the training distribution by channel
29
28
  std: standard deviation of the training distribution by channel
29
+ **kwargs: additional arguments for the resizing operation
30
30
  """
31
31
 
32
32
  def __init__(
33
33
  self,
34
- output_size: Tuple[int, int],
34
+ output_size: tuple[int, int],
35
35
  batch_size: int,
36
- mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
37
- std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
36
+ mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
37
+ std: tuple[float, float, float] = (1.0, 1.0, 1.0),
38
38
  **kwargs: Any,
39
39
  ) -> None:
40
40
  super().__init__()
@@ -43,15 +43,13 @@ class PreProcessor(nn.Module):
43
43
  # Perform the division by 255 at the same time
44
44
  self.normalize = T.Normalize(mean, std)
45
45
 
46
- def batch_inputs(self, samples: List[torch.Tensor]) -> List[torch.Tensor]:
46
+ def batch_inputs(self, samples: list[torch.Tensor]) -> list[torch.Tensor]:
47
47
  """Gather samples into batches for inference purposes
48
48
 
49
49
  Args:
50
- ----
51
50
  samples: list of samples of shape (C, H, W)
52
51
 
53
52
  Returns:
54
- -------
55
53
  list of batched samples (*, C, H, W)
56
54
  """
57
55
  num_batches = int(math.ceil(len(samples) / self.batch_size))
@@ -62,7 +60,7 @@ class PreProcessor(nn.Module):
62
60
 
63
61
  return batches
64
62
 
65
- def sample_transforms(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
63
+ def sample_transforms(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
66
64
  if x.ndim != 3:
67
65
  raise AssertionError("expected list of 3D Tensors")
68
66
  if isinstance(x, np.ndarray):
@@ -79,17 +77,15 @@ class PreProcessor(nn.Module):
79
77
  else:
80
78
  x = x.to(dtype=torch.float32) # type: ignore[union-attr]
81
79
 
82
- return x
80
+ return x # type: ignore[return-value]
83
81
 
84
- def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]:
82
+ def __call__(self, x: torch.Tensor | np.ndarray | list[torch.Tensor | np.ndarray]) -> list[torch.Tensor]:
85
83
  """Prepare document data for model forwarding
86
84
 
87
85
  Args:
88
- ----
89
86
  x: list of images (np.array) or tensors (already resized and batched)
90
87
 
91
88
  Returns:
92
- -------
93
89
  list of page batches
94
90
  """
95
91
  # Input type check
@@ -103,7 +99,7 @@ class PreProcessor(nn.Module):
103
99
  elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
104
100
  raise TypeError("unsupported data type for torch.Tensor")
105
101
  # Resizing
106
- if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
102
+ if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]: # type: ignore[union-attr]
107
103
  x = F.resize(
108
104
  x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
109
105
  )
@@ -118,11 +114,11 @@ class PreProcessor(nn.Module):
118
114
  # Sample transform (to tensor, resize)
119
115
  samples = list(multithread_exec(self.sample_transforms, x))
120
116
  # Batching
121
- batches = self.batch_inputs(samples)
117
+ batches = self.batch_inputs(samples) # type: ignore[assignment]
122
118
  else:
123
119
  raise TypeError(f"invalid input type: {type(x)}")
124
120
 
125
121
  # Batch transforms (normalize)
126
122
  batches = list(multithread_exec(self.normalize, batches))
127
123
 
128
- return batches
124
+ return batches # type: ignore[return-value]
@@ -1,10 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  import math
7
- from typing import Any, List, Tuple, Union
7
+ from typing import Any
8
8
 
9
9
  import numpy as np
10
10
  import tensorflow as tf
@@ -20,38 +20,36 @@ class PreProcessor(NestedObject):
20
20
  """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
21
21
 
22
22
  Args:
23
- ----
24
23
  output_size: expected size of each page in format (H, W)
25
24
  batch_size: the size of page batches
26
25
  mean: mean value of the training distribution by channel
27
26
  std: standard deviation of the training distribution by channel
27
+ **kwargs: additional arguments for the resizing operation
28
28
  """
29
29
 
30
- _children_names: List[str] = ["resize", "normalize"]
30
+ _children_names: list[str] = ["resize", "normalize"]
31
31
 
32
32
  def __init__(
33
33
  self,
34
- output_size: Tuple[int, int],
34
+ output_size: tuple[int, int],
35
35
  batch_size: int,
36
- mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
37
- std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
36
+ mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
37
+ std: tuple[float, float, float] = (1.0, 1.0, 1.0),
38
38
  **kwargs: Any,
39
39
  ) -> None:
40
40
  self.batch_size = batch_size
41
41
  self.resize = Resize(output_size, **kwargs)
42
42
  # Perform the division by 255 at the same time
43
43
  self.normalize = Normalize(mean, std)
44
- self._runs_on_cuda = tf.test.is_gpu_available()
44
+ self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []
45
45
 
46
- def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
46
+ def batch_inputs(self, samples: list[tf.Tensor]) -> list[tf.Tensor]:
47
47
  """Gather samples into batches for inference purposes
48
48
 
49
49
  Args:
50
- ----
51
50
  samples: list of samples (tf.Tensor)
52
51
 
53
52
  Returns:
54
- -------
55
53
  list of batched samples
56
54
  """
57
55
  num_batches = int(math.ceil(len(samples) / self.batch_size))
@@ -62,7 +60,7 @@ class PreProcessor(NestedObject):
62
60
 
63
61
  return batches
64
62
 
65
- def sample_transforms(self, x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
63
+ def sample_transforms(self, x: np.ndarray | tf.Tensor) -> tf.Tensor:
66
64
  if x.ndim != 3:
67
65
  raise AssertionError("expected list of 3D Tensors")
68
66
  if isinstance(x, np.ndarray):
@@ -79,15 +77,13 @@ class PreProcessor(NestedObject):
79
77
 
80
78
  return x
81
79
 
82
- def __call__(self, x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor, np.ndarray]]]) -> List[tf.Tensor]:
80
+ def __call__(self, x: tf.Tensor | np.ndarray | list[tf.Tensor | np.ndarray]) -> list[tf.Tensor]:
83
81
  """Prepare document data for model forwarding
84
82
 
85
83
  Args:
86
- ----
87
84
  x: list of images (np.array) or tensors (already resized and batched)
88
85
 
89
86
  Returns:
90
- -------
91
87
  list of page batches
92
88
  """
93
89
  # Input type check
@@ -1,9 +1,8 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import List, Tuple
7
6
 
8
7
  import numpy as np
9
8
 
@@ -21,17 +20,15 @@ class RecognitionModel(NestedObject):
21
20
 
22
21
  def build_target(
23
22
  self,
24
- gts: List[str],
25
- ) -> Tuple[np.ndarray, List[int]]:
23
+ gts: list[str],
24
+ ) -> tuple[np.ndarray, list[int]]:
26
25
  """Encode a list of gts sequences into a np array and gives the corresponding*
27
26
  sequence lengths.
28
27
 
29
28
  Args:
30
- ----
31
29
  gts: list of ground-truth labels
32
30
 
33
31
  Returns:
34
- -------
35
32
  A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
36
33
  """
37
34
  encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab))
@@ -43,7 +40,6 @@ class RecognitionPostProcessor(NestedObject):
43
40
  """Abstract class to postprocess the raw output of the model
44
41
 
45
42
  Args:
46
- ----
47
43
  vocab: string containing the ordered sequence of supported characters
48
44
  """
49
45
 
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]