python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172):
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
@@ -1,9 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Callable, Dict, List, Optional, Tuple
6
+ from collections.abc import Callable
7
+ from typing import Any
7
8
 
8
9
  import numpy as np
9
10
 
@@ -21,7 +22,6 @@ class _OCRPredictor:
21
22
  """Implements an object able to localize and identify text elements in a set of documents
22
23
 
23
24
  Args:
24
- ----
25
25
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
26
26
  without rotated textual elements.
27
27
  straighten_pages: if True, estimates the page general orientation based on the median line orientation.
@@ -34,8 +34,8 @@ class _OCRPredictor:
34
34
  **kwargs: keyword args of `DocumentBuilder`
35
35
  """
36
36
 
37
- crop_orientation_predictor: Optional[OrientationPredictor]
38
- page_orientation_predictor: Optional[OrientationPredictor]
37
+ crop_orientation_predictor: OrientationPredictor | None
38
+ page_orientation_predictor: OrientationPredictor | None
39
39
 
40
40
  def __init__(
41
41
  self,
@@ -63,12 +63,12 @@ class _OCRPredictor:
63
63
  self.doc_builder = DocumentBuilder(**kwargs)
64
64
  self.preserve_aspect_ratio = preserve_aspect_ratio
65
65
  self.symmetric_pad = symmetric_pad
66
- self.hooks: List[Callable] = []
66
+ self.hooks: list[Callable] = []
67
67
 
68
68
  def _general_page_orientations(
69
69
  self,
70
- pages: List[np.ndarray],
71
- ) -> List[Tuple[int, float]]:
70
+ pages: list[np.ndarray],
71
+ ) -> list[tuple[int, float]]:
72
72
  _, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
73
73
  # Flatten to list of tuples with (value, confidence)
74
74
  page_orientations = [
@@ -79,8 +79,8 @@ class _OCRPredictor:
79
79
  return page_orientations
80
80
 
81
81
  def _get_orientations(
82
- self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
83
- ) -> Tuple[List[Tuple[int, float]], List[int]]:
82
+ self, pages: list[np.ndarray], seg_maps: list[np.ndarray]
83
+ ) -> tuple[list[tuple[int, float]], list[int]]:
84
84
  general_pages_orientations = self._general_page_orientations(pages)
85
85
  origin_page_orientations = [
86
86
  estimate_orientation(seq_map, general_orientation)
@@ -90,11 +90,11 @@ class _OCRPredictor:
90
90
 
91
91
  def _straighten_pages(
92
92
  self,
93
- pages: List[np.ndarray],
94
- seg_maps: List[np.ndarray],
95
- general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
96
- origin_pages_orientations: Optional[List[int]] = None,
97
- ) -> List[np.ndarray]:
93
+ pages: list[np.ndarray],
94
+ seg_maps: list[np.ndarray],
95
+ general_pages_orientations: list[tuple[int, float]] | None = None,
96
+ origin_pages_orientations: list[int] | None = None,
97
+ ) -> list[np.ndarray]:
98
98
  general_pages_orientations = (
99
99
  general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
100
100
  )
@@ -114,12 +114,12 @@ class _OCRPredictor:
114
114
 
115
115
  @staticmethod
116
116
  def _generate_crops(
117
- pages: List[np.ndarray],
118
- loc_preds: List[np.ndarray],
117
+ pages: list[np.ndarray],
118
+ loc_preds: list[np.ndarray],
119
119
  channels_last: bool,
120
120
  assume_straight_pages: bool = False,
121
121
  assume_horizontal: bool = False,
122
- ) -> List[List[np.ndarray]]:
122
+ ) -> list[list[np.ndarray]]:
123
123
  if assume_straight_pages:
124
124
  crops = [
125
125
  extract_crops(page, _boxes[:, :4], channels_last=channels_last)
@@ -134,12 +134,12 @@ class _OCRPredictor:
134
134
 
135
135
  @staticmethod
136
136
  def _prepare_crops(
137
- pages: List[np.ndarray],
138
- loc_preds: List[np.ndarray],
137
+ pages: list[np.ndarray],
138
+ loc_preds: list[np.ndarray],
139
139
  channels_last: bool,
140
140
  assume_straight_pages: bool = False,
141
141
  assume_horizontal: bool = False,
142
- ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
142
+ ) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
143
143
  crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
144
144
 
145
145
  # Avoid sending zero-sized crops
@@ -154,9 +154,9 @@ class _OCRPredictor:
154
154
 
155
155
  def _rectify_crops(
156
156
  self,
157
- crops: List[List[np.ndarray]],
158
- loc_preds: List[np.ndarray],
159
- ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]:
157
+ crops: list[list[np.ndarray]],
158
+ loc_preds: list[np.ndarray],
159
+ ) -> tuple[list[list[np.ndarray]], list[np.ndarray], list[tuple[int, float]]]:
160
160
  # Work at a page level
161
161
  orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc]
162
162
  rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
@@ -174,10 +174,10 @@ class _OCRPredictor:
174
174
 
175
175
  @staticmethod
176
176
  def _process_predictions(
177
- loc_preds: List[np.ndarray],
178
- word_preds: List[Tuple[str, float]],
179
- crop_orientations: List[Dict[str, Any]],
180
- ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]:
177
+ loc_preds: list[np.ndarray],
178
+ word_preds: list[tuple[str, float]],
179
+ crop_orientations: list[dict[str, Any]],
180
+ ) -> tuple[list[np.ndarray], list[list[tuple[str, float]]], list[list[dict[str, Any]]]]:
181
181
  text_preds = []
182
182
  crop_orientation_preds = []
183
183
  if len(loc_preds) > 0:
@@ -194,7 +194,6 @@ class _OCRPredictor:
194
194
  """Add a hook to the predictor
195
195
 
196
196
  Args:
197
- ----
198
197
  hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
199
198
  """
200
199
  self.hooks.append(hook)
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, List, Union
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import torch
@@ -24,7 +24,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
24
24
  """Implements an object able to localize and identify text elements in a set of documents
25
25
 
26
26
  Args:
27
- ----
28
27
  det_predictor: detection module
29
28
  reco_predictor: recognition module
30
29
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -52,8 +51,8 @@ class OCRPredictor(nn.Module, _OCRPredictor):
52
51
  **kwargs: Any,
53
52
  ) -> None:
54
53
  nn.Module.__init__(self)
55
- self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
56
- self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
54
+ self.det_predictor = det_predictor.eval()
55
+ self.reco_predictor = reco_predictor.eval()
57
56
  _OCRPredictor.__init__(
58
57
  self,
59
58
  assume_straight_pages,
@@ -69,7 +68,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
69
68
  @torch.inference_mode()
70
69
  def forward(
71
70
  self,
72
- pages: List[Union[np.ndarray, torch.Tensor]],
71
+ pages: list[np.ndarray | torch.Tensor],
73
72
  **kwargs: Any,
74
73
  ) -> Document:
75
74
  # Dimension check
@@ -87,7 +86,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
87
86
  for out_map in out_maps
88
87
  ]
89
88
  if self.detect_orientation:
90
- general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) # type: ignore[arg-type]
89
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
91
90
  orientations = [
92
91
  {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
93
92
  ]
@@ -96,16 +95,16 @@ class OCRPredictor(nn.Module, _OCRPredictor):
96
95
  general_pages_orientations = None
97
96
  origin_pages_orientations = None
98
97
  if self.straighten_pages:
99
- pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
98
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
100
99
  # update page shapes after straightening
101
100
  origin_page_shapes = [page.shape[:2] for page in pages]
102
101
 
103
102
  # Forward again to get predictions on straight pages
104
103
  loc_preds = self.det_predictor(pages, **kwargs)
105
104
 
106
- assert all(
107
- len(loc_pred) == 1 for loc_pred in loc_preds
108
- ), "Detection Model in ocr_predictor should output only one class"
105
+ assert all(len(loc_pred) == 1 for loc_pred in loc_preds), (
106
+ "Detection Model in ocr_predictor should output only one class"
107
+ )
109
108
 
110
109
  loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
111
110
  # Detach objectness scores from loc_preds
@@ -119,7 +118,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
119
118
 
120
119
  # Crop images
121
120
  crops, loc_preds = self._prepare_crops(
122
- pages, # type: ignore[arg-type]
121
+ pages,
123
122
  loc_preds,
124
123
  channels_last=channels_last,
125
124
  assume_straight_pages=self.assume_straight_pages,
@@ -147,11 +146,11 @@ class OCRPredictor(nn.Module, _OCRPredictor):
147
146
  languages_dict = None
148
147
 
149
148
  out = self.doc_builder(
150
- pages, # type: ignore[arg-type]
149
+ pages,
151
150
  boxes,
152
151
  objectness_scores,
153
152
  text_preds,
154
- origin_page_shapes, # type: ignore[arg-type]
153
+ origin_page_shapes,
155
154
  crop_orientations,
156
155
  orientations,
157
156
  languages_dict,
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, List, Union
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import tensorflow as tf
@@ -24,7 +24,6 @@ class OCRPredictor(NestedObject, _OCRPredictor):
24
24
  """Implements an object able to localize and identify text elements in a set of documents
25
25
 
26
26
  Args:
27
- ----
28
27
  det_predictor: detection module
29
28
  reco_predictor: recognition module
30
29
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -69,7 +68,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
69
68
 
70
69
  def __call__(
71
70
  self,
72
- pages: List[Union[np.ndarray, tf.Tensor]],
71
+ pages: list[np.ndarray | tf.Tensor],
73
72
  **kwargs: Any,
74
73
  ) -> Document:
75
74
  # Dimension check
@@ -101,12 +100,12 @@ class OCRPredictor(NestedObject, _OCRPredictor):
101
100
  origin_page_shapes = [page.shape[:2] for page in pages]
102
101
 
103
102
  # forward again to get predictions on straight pages
104
- loc_preds_dict = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
103
+ loc_preds_dict = self.det_predictor(pages, **kwargs)
105
104
 
106
- assert all(
107
- len(loc_pred) == 1 for loc_pred in loc_preds_dict
108
- ), "Detection Model in ocr_predictor should output only one class"
109
- loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # type: ignore[union-attr]
105
+ assert all(len(loc_pred) == 1 for loc_pred in loc_preds_dict), (
106
+ "Detection Model in ocr_predictor should output only one class"
107
+ )
108
+ loc_preds: list[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict]
110
109
  # Detach objectness scores from loc_preds
111
110
  loc_preds, objectness_scores = detach_scores(loc_preds)
112
111
 
@@ -148,7 +147,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
148
147
  boxes,
149
148
  objectness_scores,
150
149
  text_preds,
151
- origin_page_shapes, # type: ignore[arg-type]
150
+ origin_page_shapes,
152
151
  crop_orientations,
153
152
  orientations,
154
153
  languages_dict,
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]
@@ -1,10 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  import math
7
- from typing import Any, List, Tuple, Union
7
+ from typing import Any
8
8
 
9
9
  import numpy as np
10
10
  import torch
@@ -22,19 +22,19 @@ class PreProcessor(nn.Module):
22
22
  """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
23
23
 
24
24
  Args:
25
- ----
26
25
  output_size: expected size of each page in format (H, W)
27
26
  batch_size: the size of page batches
28
27
  mean: mean value of the training distribution by channel
29
28
  std: standard deviation of the training distribution by channel
29
+ **kwargs: additional arguments for the resizing operation
30
30
  """
31
31
 
32
32
  def __init__(
33
33
  self,
34
- output_size: Tuple[int, int],
34
+ output_size: tuple[int, int],
35
35
  batch_size: int,
36
- mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
37
- std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
36
+ mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
37
+ std: tuple[float, float, float] = (1.0, 1.0, 1.0),
38
38
  **kwargs: Any,
39
39
  ) -> None:
40
40
  super().__init__()
@@ -43,15 +43,13 @@ class PreProcessor(nn.Module):
43
43
  # Perform the division by 255 at the same time
44
44
  self.normalize = T.Normalize(mean, std)
45
45
 
46
- def batch_inputs(self, samples: List[torch.Tensor]) -> List[torch.Tensor]:
46
+ def batch_inputs(self, samples: list[torch.Tensor]) -> list[torch.Tensor]:
47
47
  """Gather samples into batches for inference purposes
48
48
 
49
49
  Args:
50
- ----
51
50
  samples: list of samples of shape (C, H, W)
52
51
 
53
52
  Returns:
54
- -------
55
53
  list of batched samples (*, C, H, W)
56
54
  """
57
55
  num_batches = int(math.ceil(len(samples) / self.batch_size))
@@ -62,7 +60,7 @@ class PreProcessor(nn.Module):
62
60
 
63
61
  return batches
64
62
 
65
- def sample_transforms(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
63
+ def sample_transforms(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
66
64
  if x.ndim != 3:
67
65
  raise AssertionError("expected list of 3D Tensors")
68
66
  if isinstance(x, np.ndarray):
@@ -79,17 +77,15 @@ class PreProcessor(nn.Module):
79
77
  else:
80
78
  x = x.to(dtype=torch.float32) # type: ignore[union-attr]
81
79
 
82
- return x
80
+ return x # type: ignore[return-value]
83
81
 
84
- def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]:
82
+ def __call__(self, x: torch.Tensor | np.ndarray | list[torch.Tensor | np.ndarray]) -> list[torch.Tensor]:
85
83
  """Prepare document data for model forwarding
86
84
 
87
85
  Args:
88
- ----
89
86
  x: list of images (np.array) or tensors (already resized and batched)
90
87
 
91
88
  Returns:
92
- -------
93
89
  list of page batches
94
90
  """
95
91
  # Input type check
@@ -103,7 +99,7 @@ class PreProcessor(nn.Module):
103
99
  elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
104
100
  raise TypeError("unsupported data type for torch.Tensor")
105
101
  # Resizing
106
- if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
102
+ if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]: # type: ignore[union-attr]
107
103
  x = F.resize(
108
104
  x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
109
105
  )
@@ -118,11 +114,11 @@ class PreProcessor(nn.Module):
118
114
  # Sample transform (to tensor, resize)
119
115
  samples = list(multithread_exec(self.sample_transforms, x))
120
116
  # Batching
121
- batches = self.batch_inputs(samples)
117
+ batches = self.batch_inputs(samples) # type: ignore[assignment]
122
118
  else:
123
119
  raise TypeError(f"invalid input type: {type(x)}")
124
120
 
125
121
  # Batch transforms (normalize)
126
122
  batches = list(multithread_exec(self.normalize, batches))
127
123
 
128
- return batches
124
+ return batches # type: ignore[return-value]
@@ -1,10 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  import math
7
- from typing import Any, List, Tuple, Union
7
+ from typing import Any
8
8
 
9
9
  import numpy as np
10
10
  import tensorflow as tf
@@ -20,21 +20,21 @@ class PreProcessor(NestedObject):
20
20
  """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
21
21
 
22
22
  Args:
23
- ----
24
23
  output_size: expected size of each page in format (H, W)
25
24
  batch_size: the size of page batches
26
25
  mean: mean value of the training distribution by channel
27
26
  std: standard deviation of the training distribution by channel
27
+ **kwargs: additional arguments for the resizing operation
28
28
  """
29
29
 
30
- _children_names: List[str] = ["resize", "normalize"]
30
+ _children_names: list[str] = ["resize", "normalize"]
31
31
 
32
32
  def __init__(
33
33
  self,
34
- output_size: Tuple[int, int],
34
+ output_size: tuple[int, int],
35
35
  batch_size: int,
36
- mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
37
- std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
36
+ mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
37
+ std: tuple[float, float, float] = (1.0, 1.0, 1.0),
38
38
  **kwargs: Any,
39
39
  ) -> None:
40
40
  self.batch_size = batch_size
@@ -43,15 +43,13 @@ class PreProcessor(NestedObject):
43
43
  self.normalize = Normalize(mean, std)
44
44
  self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []
45
45
 
46
- def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
46
+ def batch_inputs(self, samples: list[tf.Tensor]) -> list[tf.Tensor]:
47
47
  """Gather samples into batches for inference purposes
48
48
 
49
49
  Args:
50
- ----
51
50
  samples: list of samples (tf.Tensor)
52
51
 
53
52
  Returns:
54
- -------
55
53
  list of batched samples
56
54
  """
57
55
  num_batches = int(math.ceil(len(samples) / self.batch_size))
@@ -62,7 +60,7 @@ class PreProcessor(NestedObject):
62
60
 
63
61
  return batches
64
62
 
65
- def sample_transforms(self, x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
63
+ def sample_transforms(self, x: np.ndarray | tf.Tensor) -> tf.Tensor:
66
64
  if x.ndim != 3:
67
65
  raise AssertionError("expected list of 3D Tensors")
68
66
  if isinstance(x, np.ndarray):
@@ -79,15 +77,13 @@ class PreProcessor(NestedObject):
79
77
 
80
78
  return x
81
79
 
82
- def __call__(self, x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor, np.ndarray]]]) -> List[tf.Tensor]:
80
+ def __call__(self, x: tf.Tensor | np.ndarray | list[tf.Tensor | np.ndarray]) -> list[tf.Tensor]:
83
81
  """Prepare document data for model forwarding
84
82
 
85
83
  Args:
86
- ----
87
84
  x: list of images (np.array) or tensors (already resized and batched)
88
85
 
89
86
  Returns:
90
- -------
91
87
  list of page batches
92
88
  """
93
89
  # Input type check
@@ -3,4 +3,5 @@ from .master import *
3
3
  from .sar import *
4
4
  from .vitstr import *
5
5
  from .parseq import *
6
+ from .viptr import *
6
7
  from .zoo import *
@@ -1,9 +1,8 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import List, Tuple
7
6
 
8
7
  import numpy as np
9
8
 
@@ -21,17 +20,15 @@ class RecognitionModel(NestedObject):
21
20
 
22
21
  def build_target(
23
22
  self,
24
- gts: List[str],
25
- ) -> Tuple[np.ndarray, List[int]]:
23
+ gts: list[str],
24
+ ) -> tuple[np.ndarray, list[int]]:
26
25
  """Encode a list of gts sequences into a np array and gives the corresponding*
27
26
  sequence lengths.
28
27
 
29
28
  Args:
30
- ----
31
29
  gts: list of ground-truth labels
32
30
 
33
31
  Returns:
34
- -------
35
32
  A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
36
33
  """
37
34
  encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab))
@@ -43,7 +40,6 @@ class RecognitionPostProcessor(NestedObject):
43
40
  """Abstract class to postprocess the raw output of the model
44
41
 
45
42
  Args:
46
- ----
47
43
  vocab: string containing the ordered sequence of supported characters
48
44
  """
49
45
 
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]