python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,6 +1,6 @@
- from doctr.file_utils import is_tf_available
+ from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
- from .tensorflow import *
- else:
- from .pytorch import * # type: ignore[assignment]
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
+ from .tensorflow import * # type: ignore[assignment]
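This backend-selection hunk recurs across the package's `__init__.py` files in 0.11.0: `is_torch_available()` is now checked first, so the PyTorch implementation is preferred when both backends are installed. Reconstructed from the added lines above (indentation restored), the new module body reads:

    # New import order, reconstructed from the "+" lines of this hunk
    from doctr.file_utils import is_tf_available, is_torch_available

    if is_torch_available():
        from .pytorch import *
    elif is_tf_available():
        from .tensorflow import *  # type: ignore[assignment]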
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Optional
+ from typing import Any

  from doctr.models.builder import KIEDocumentBuilder

@@ -17,7 +17,6 @@ class _KIEPredictor(_OCRPredictor):
  """Implements an object able to localize and identify text elements in a set of documents

  Args:
- ----
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
  without rotated textual elements.
  straighten_pages: if True, estimates the page general orientation based on the median line orientation.
@@ -30,8 +29,8 @@ class _KIEPredictor(_OCRPredictor):
  kwargs: keyword args of `DocumentBuilder`
  """

- crop_orientation_predictor: Optional[OrientationPredictor]
- page_orientation_predictor: Optional[OrientationPredictor]
+ crop_orientation_predictor: OrientationPredictor | None
+ page_orientation_predictor: OrientationPredictor | None

  def __init__(
  self,
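The annotation changes in this hunk follow a pattern repeated throughout the release: `Optional[X]` becomes `X | None` (PEP 604) and the `typing` aliases `List`, `Dict` and `Tuple` become the built-in generics `list`, `dict` and `tuple` (PEP 585), which is why most files now import only `Any` from `typing`. A minimal before/after sketch with hypothetical names:

    from typing import Dict, List, Optional  # 0.9.0 style

    def get_words_old(preds: Optional[Dict[str, List[float]]] = None) -> List[str]: ...

    # 0.11.0 style: built-in generics and the union operator, no extra typing imports
    def get_words_new(preds: dict[str, list[float]] | None = None) -> list[str]: ...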
@@ -46,4 +45,8 @@ class _KIEPredictor(_OCRPredictor):
  assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs
  )

+ # Remove the following arguments from kwargs after initialization of the parent class
+ kwargs.pop("disable_page_orientation", None)
+ kwargs.pop("disable_crop_orientation", None)
+
  self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Union
+ from typing import Any

  import numpy as np
  import torch
@@ -24,7 +24,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  """Implements an object able to localize and identify text elements in a set of documents

  Args:
- ----
  det_predictor: detection module
  reco_predictor: recognition module
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -52,8 +51,8 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  **kwargs: Any,
  ) -> None:
  nn.Module.__init__(self)
- self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
- self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
+ self.det_predictor = det_predictor.eval()
+ self.reco_predictor = reco_predictor.eval()
  _KIEPredictor.__init__(
  self,
  assume_straight_pages,
@@ -69,7 +68,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  @torch.inference_mode()
  def forward(
  self,
- pages: List[Union[np.ndarray, torch.Tensor]],
+ pages: list[np.ndarray | torch.Tensor],
  **kwargs: Any,
  ) -> Document:
  # Dimension check
@@ -89,7 +88,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  for out_map in out_maps
  ]
  if self.detect_orientation:
- general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) # type: ignore[arg-type]
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
  orientations = [
  {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
  ]
@@ -98,11 +97,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  general_pages_orientations = None
  origin_pages_orientations = None
  if self.straighten_pages:
- pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+ # update page shapes after straightening
+ origin_page_shapes = [page.shape[:2] for page in pages]
+
  # Forward again to get predictions on straight pages
  loc_preds = self.det_predictor(pages, **kwargs)

- dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]
+ dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]

  # Detach objectness scores from loc_preds
  objectness_scores = {}
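The added `origin_page_shapes` recomputation addresses the fact that `_straighten_pages` may rotate pages and therefore change their height and width, so shapes recorded before straightening would no longer match the pages used downstream. A hedged sketch of the idea, with an illustrative stand-in for the straightening step:

    import numpy as np

    def straighten_and_track(pages: list[np.ndarray]) -> tuple[list[np.ndarray], list[tuple[int, int]]]:
        # Stand-in for the straightening step: a 90-degree rotation swaps height and width
        straightened = [np.rot90(page) for page in pages]
        # What 0.11.0 now recomputes, so downstream geometry uses the post-rotation shapes
        shapes = [page.shape[:2] for page in straightened]
        return straightened, shapes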
@@ -122,10 +124,11 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  crops = {}
  for class_name in dict_loc_preds.keys():
  crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
- pages, # type: ignore[arg-type]
+ pages,
  dict_loc_preds[class_name],
  channels_last=channels_last,
  assume_straight_pages=self.assume_straight_pages,
+ assume_horizontal=self._page_orientation_disabled,
  )
  # Rectify crop orientation
  crop_orientations: Any = {}
@@ -146,18 +149,18 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  if not crop_orientations:
  crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}

- boxes: Dict = {}
- text_preds: Dict = {}
- word_crop_orientations: Dict = {}
+ boxes: dict = {}
+ text_preds: dict = {}
+ word_crop_orientations: dict = {}
  for class_name in dict_loc_preds.keys():
  boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
  dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
  )

- boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
- objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
- text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
- crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
+ boxes_per_page: list[dict] = invert_data_structure(boxes) # type: ignore[assignment]
+ objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
+ text_preds_per_page: list[dict] = invert_data_structure(text_preds) # type: ignore[assignment]
+ crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]

  if self.detect_language:
  languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -166,7 +169,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  languages_dict = None

  out = self.doc_builder(
- pages, # type: ignore[arg-type]
+ pages,
  boxes_per_page,
  objectness_scores_per_page,
  text_preds_per_page,
@@ -178,7 +181,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  return out

  @staticmethod
- def get_text(text_pred: Dict) -> str:
+ def get_text(text_pred: dict) -> str:
  text = []
  for value in text_pred.values():
  text += [item[0] for item in value]
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Union
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -24,7 +24,6 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  """Implements an object able to localize and identify text elements in a set of documents

  Args:
- ----
  det_predictor: detection module
  reco_predictor: recognition module
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -69,7 +68,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):

  def __call__(
  self,
- pages: List[Union[np.ndarray, tf.Tensor]],
+ pages: list[np.ndarray | tf.Tensor],
  **kwargs: Any,
  ) -> Document:
  # Dimension check
@@ -99,10 +98,13 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  origin_pages_orientations = None
  if self.straighten_pages:
  pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+ # update page shapes after straightening
+ origin_page_shapes = [page.shape[:2] for page in pages]
+
  # Forward again to get predictions on straight pages
- loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
+ loc_preds = self.det_predictor(pages, **kwargs)

- dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
+ dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore

  # Detach objectness scores from loc_preds
  objectness_scores = {}
@@ -119,7 +121,11 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  crops = {}
  for class_name in dict_loc_preds.keys():
  crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
- pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
+ pages,
+ dict_loc_preds[class_name],
+ channels_last=True,
+ assume_straight_pages=self.assume_straight_pages,
+ assume_horizontal=self._page_orientation_disabled,
  )

  # Rectify crop orientation
@@ -141,18 +147,18 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  if not crop_orientations:
  crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}

- boxes: Dict = {}
- text_preds: Dict = {}
- word_crop_orientations: Dict = {}
+ boxes: dict = {}
+ text_preds: dict = {}
+ word_crop_orientations: dict = {}
  for class_name in dict_loc_preds.keys():
  boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
  dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
  )

- boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
- objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
- text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
- crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
+ boxes_per_page: list[dict] = invert_data_structure(boxes) # type: ignore[assignment]
+ objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
+ text_preds_per_page: list[dict] = invert_data_structure(text_preds) # type: ignore[assignment]
+ crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]

  if self.detect_language:
  languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -173,7 +179,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  return out

  @staticmethod
- def get_text(text_pred: Dict) -> str:
+ def get_text(text_pred: dict) -> str:
  text = []
  for value in text_pred.values():
  text += [item[0] for item in value]
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
  from .tensorflow import *
- elif is_torch_available():
- from .pytorch import * # type: ignore[assignment]
@@ -1,9 +1,8 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Tuple, Union

  import numpy as np
  import torch
@@ -19,7 +18,7 @@ class FASTConvLayer(nn.Module):
  self,
  in_channels: int,
  out_channels: int,
- kernel_size: Union[int, Tuple[int, int]],
+ kernel_size: int | tuple[int, int],
  stride: int = 1,
  dilation: int = 1,
  groups: int = 1,
@@ -93,9 +92,7 @@ class FASTConvLayer(nn.Module):

  # The following logic is used to reparametrize the layer
  # Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
- def _identity_to_conv(
- self, identity: Union[nn.BatchNorm2d, None]
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
+ def _identity_to_conv(self, identity: nn.BatchNorm2d | None) -> tuple[torch.Tensor, torch.Tensor] | tuple[int, int]:
  if identity is None or identity.running_var is None:
  return 0, 0
  if not hasattr(self, "id_tensor"):
@@ -106,18 +103,18 @@ class FASTConvLayer(nn.Module):
  id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
  self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
  kernel = self.id_tensor
- std = (identity.running_var + identity.eps).sqrt()
+ std = (identity.running_var + identity.eps).sqrt() # type: ignore
  t = (identity.weight / std).reshape(-1, 1, 1, 1)
  return kernel * t, identity.bias - identity.running_mean * identity.weight / std

- def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
  kernel = conv.weight
  kernel = self._pad_to_mxn_tensor(kernel)
  std = (bn.running_var + bn.eps).sqrt() # type: ignore
  t = (bn.weight / std).reshape(-1, 1, 1, 1)
  return kernel * t, bn.bias - bn.running_mean * bn.weight / std

- def _get_equivalent_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
  kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
  if self.ver_conv is not None:
  kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) # type: ignore[arg-type]
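The `_fuse_bn_tensor` and `_get_equivalent_kernel_bias` changes above are annotation-only, but the underlying reparametrization is worth spelling out: folding a BatchNorm (running mean/var, weight gamma, bias beta, epsilon) into the preceding convolution scales the kernel per output channel by gamma / sqrt(var + eps) and produces the bias beta - mean * gamma / sqrt(var + eps). A standalone sketch of that math (not the doctr method itself, and assuming the conv has no bias of its own):

    import torch
    from torch import nn

    def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
        # Per-output-channel scale derived from the BN statistics
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        fused_weight = conv.weight * t
        fused_bias = bn.bias - bn.running_mean * bn.weight / std
        return fused_weight, fused_bias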
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Tuple, Union
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -21,7 +21,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
  self,
  in_channels: int,
  out_channels: int,
- kernel_size: Union[int, Tuple[int, int]],
+ kernel_size: int | tuple[int, int],
  stride: int = 1,
  dilation: int = 1,
  groups: int = 1,
@@ -103,9 +103,7 @@ class FASTConvLayer(layers.Layer, NestedObject):

  # The following logic is used to reparametrize the layer
  # Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
- def _identity_to_conv(
- self, identity: layers.BatchNormalization
- ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
+ def _identity_to_conv(self, identity: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor] | tuple[int, int]:
  if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
  return 0, 0
  if not hasattr(self, "id_tensor"):
@@ -120,7 +118,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
  t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
  return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std

- def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
+ def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor]:
  kernel = conv.kernel
  kernel = self._pad_to_mxn_tensor(kernel)
  std = tf.sqrt(bn.moving_variance + bn.epsilon)
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
  from .tensorflow import *
- elif is_torch_available():
- from .pytorch import * # type: ignore[assignment]
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,8 @@
  # This module 'transformer.py' is inspired by https://github.com/wenwenyu/MASTER-pytorch and Decoder is borrowed

  import math
- from typing import Any, Callable, Optional, Tuple
+ from collections.abc import Callable
+ from typing import Any

  import torch
  from torch import nn
@@ -33,26 +34,24 @@ class PositionalEncoding(nn.Module):
  """Forward pass

  Args:
- ----
  x: embeddings (batch, max_len, d_model)

- Returns
- -------
+ Returns:
  positional embeddings (batch, max_len, d_model)
  """
- x = x + self.pe[:, : x.size(1)]
+ x = x + self.pe[:, : x.size(1)] # type: ignore[index]
  return self.dropout(x)


  def scaled_dot_product_attention(
- query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
  """Scaled Dot-Product Attention"""
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
  if mask is not None:
  # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
- scores = scores.masked_fill(mask == 0, float("-inf"))
- p_attn = torch.softmax(scores, dim=-1)
+ scores = scores.masked_fill(mask == 0, float("-inf")) # type: ignore[attr-defined]
+ p_attn = torch.softmax(scores, dim=-1) # type: ignore[call-overload]
  return torch.matmul(p_attn, value), p_attn

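For context, the `scaled_dot_product_attention` helper shown above computes softmax(QK^T / sqrt(d_k))V and returns both the weighted values and the attention map; the mask is compared against 0 (rather than using boolean masking) to keep the op ONNX-exportable. A self-contained usage sketch with made-up shapes, using a local copy of the same math:

    import math
    import torch

    def sdpa(query, key, value, mask=None):
        # Same math as the hunk above: softmax(QK^T / sqrt(d_k)) V
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        p_attn = torch.softmax(scores, dim=-1)
        return torch.matmul(p_attn, value), p_attn

    q = k = v = torch.randn(2, 8, 16, 64)          # (batch, heads, seq_len, d_k)
    causal = torch.tril(torch.ones(16, 16)).int()  # 1 = attend, 0 = masked
    out, attn = sdpa(q, k, v, mask=causal)
    print(out.shape, attn.shape)                   # (2, 8, 16, 64) and (2, 8, 16, 16)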
@@ -130,7 +129,7 @@ class EncoderBlock(nn.Module):
  PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
  ])

- def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
  output = x

  for i in range(self.num_layers):
@@ -183,8 +182,8 @@ class Decoder(nn.Module):
  self,
  tgt: torch.Tensor,
  memory: torch.Tensor,
- source_mask: Optional[torch.Tensor] = None,
- target_mask: Optional[torch.Tensor] = None,
+ source_mask: torch.Tensor | None = None,
+ target_mask: torch.Tensor | None = None,
  ) -> torch.Tensor:
  tgt = self.embed(tgt) * math.sqrt(self.d_model)
  pos_enc_tgt = self.positional_encoding(tgt)
@@ -1,10 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import math
- from typing import Any, Callable, Optional, Tuple
+ from collections.abc import Callable
+ from typing import Any

  import tensorflow as tf
  from tensorflow.keras import layers
@@ -13,8 +14,6 @@ from doctr.utils.repr import NestedObject

  __all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "PositionwiseFeedForward", "MultiHeadAttention"]

- tf.config.run_functions_eagerly(True)
-

  class PositionalEncoding(layers.Layer, NestedObject):
  """Compute positional encoding"""
@@ -45,12 +44,10 @@ class PositionalEncoding(layers.Layer, NestedObject):
  """Forward pass

  Args:
- ----
  x: embeddings (batch, max_len, d_model)
  **kwargs: additional arguments

- Returns
- -------
+ Returns:
  positional embeddings (batch, max_len, d_model)
  """
  if x.dtype == tf.float16: # amp fix: cast to half
@@ -62,8 +59,8 @@

  @tf.function
  def scaled_dot_product_attention(
- query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: Optional[tf.Tensor] = None
- ) -> Tuple[tf.Tensor, tf.Tensor]:
+ query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: tf.Tensor | None = None
+ ) -> tuple[tf.Tensor, tf.Tensor]:
  """Scaled Dot-Product Attention"""
  scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) / math.sqrt(query.shape[-1])
  if mask is not None:
@@ -162,7 +159,7 @@ class EncoderBlock(layers.Layer, NestedObject):
  PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
  ]

- def call(self, x: tf.Tensor, mask: Optional[tf.Tensor] = None, **kwargs: Any) -> tf.Tensor:
+ def call(self, x: tf.Tensor, mask: tf.Tensor | None = None, **kwargs: Any) -> tf.Tensor:
  output = x

  for i in range(self.num_layers):
@@ -212,8 +209,8 @@ class Decoder(layers.Layer, NestedObject):
  self,
  tgt: tf.Tensor,
  memory: tf.Tensor,
- source_mask: Optional[tf.Tensor] = None,
- target_mask: Optional[tf.Tensor] = None,
+ source_mask: tf.Tensor | None = None,
+ target_mask: tf.Tensor | None = None,
  **kwargs: Any,
  ) -> tf.Tensor:
  tgt = self.embed(tgt, **kwargs) * math.sqrt(self.d_model)
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
  from .tensorflow import *
- elif is_torch_available():
- from .pytorch import * # type: ignore[assignment]
@@ -1,10 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import math
- from typing import Tuple

  import torch
  from torch import nn
@@ -15,12 +14,12 @@ __all__ = ["PatchEmbedding"]
  class PatchEmbedding(nn.Module):
  """Compute 2D patch embeddings with cls token and positional encoding"""

- def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None:
+ def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
  super().__init__()
  channels, height, width = input_shape
  self.patch_size = patch_size
  self.interpolate = True if patch_size[0] == patch_size[1] else False
- self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)])
+ self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
  self.num_patches = self.grid_size[0] * self.grid_size[1]

  self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
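The `grid_size` change above merely swaps a list comprehension for a generator expression; the computed values are unchanged. For reference, the grid is the input size integer-divided by the patch size along each dimension, and `num_patches` is their product. A worked example with illustrative numbers:

    height, width = 32, 128        # input image size (illustrative)
    patch_size = (4, 8)

    grid_size = tuple(s // p for s, p in zip((height, width), patch_size))
    num_patches = grid_size[0] * grid_size[1]
    print(grid_size, num_patches)  # (8, 16) 128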
@@ -1,10 +1,10 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import math
- from typing import Any, Tuple
+ from typing import Any

  import tensorflow as tf
  from tensorflow.keras import layers
@@ -17,12 +17,12 @@ __all__ = ["PatchEmbedding"]
  class PatchEmbedding(layers.Layer, NestedObject):
  """Compute 2D patch embeddings with cls token and positional encoding"""

- def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None:
+ def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
  super().__init__()
  height, width, _ = input_shape
  self.patch_size = patch_size
  self.interpolate = True if patch_size[0] == patch_size[1] else False
- self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)])
+ self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
  self.num_patches = self.grid_size[0] * self.grid_size[1]

  self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token")
@@ -1,6 +1,6 @@
- from doctr.file_utils import is_tf_available
+ from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
- from .tensorflow import *
- else:
- from .pytorch import * # type: ignore[assignment]
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
+ from .tensorflow import * # type: ignore[assignment]