python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +8 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +7 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +4 -5
  17. doctr/datasets/ic13.py +4 -5
  18. doctr/datasets/iiit5k.py +6 -5
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +6 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +6 -5
  27. doctr/datasets/svhn.py +6 -5
  28. doctr/datasets/svt.py +4 -5
  29. doctr/datasets/synthtext.py +4 -5
  30. doctr/datasets/utils.py +34 -29
  31. doctr/datasets/vocabs.py +17 -7
  32. doctr/datasets/wildreceipt.py +14 -10
  33. doctr/file_utils.py +2 -7
  34. doctr/io/elements.py +59 -79
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +30 -48
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +8 -11
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +5 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +8 -21
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +6 -8
  52. doctr/models/classification/predictor/tensorflow.py +6 -8
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +20 -31
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +8 -15
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +9 -12
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +6 -12
  65. doctr/models/classification/zoo.py +19 -14
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +14 -26
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +14 -23
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +5 -6
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +3 -7
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +4 -5
  91. doctr/models/kie_predictor/pytorch.py +18 -19
  92. doctr/models/kie_predictor/tensorflow.py +13 -14
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -10
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  101. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +28 -29
  104. doctr/models/predictor/pytorch.py +12 -13
  105. doctr/models/predictor/tensorflow.py +8 -9
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +10 -14
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +11 -23
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +12 -22
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +16 -22
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +12 -21
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +12 -20
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +14 -17
  136. doctr/models/utils/tensorflow.py +17 -16
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +20 -28
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +58 -22
  145. doctr/transforms/modules/tensorflow.py +18 -32
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +16 -47
  150. doctr/utils/metrics.py +17 -37
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +9 -13
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.10.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/kie_predictor/base.py

@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Optional
+ from typing import Any

  from doctr.models.builder import KIEDocumentBuilder

@@ -17,7 +17,6 @@ class _KIEPredictor(_OCRPredictor):
  """Implements an object able to localize and identify text elements in a set of documents

  Args:
- ----
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
  without rotated textual elements.
  straighten_pages: if True, estimates the page general orientation based on the median line orientation.
@@ -30,8 +29,8 @@ class _KIEPredictor(_OCRPredictor):
  kwargs: keyword args of `DocumentBuilder`
  """

- crop_orientation_predictor: Optional[OrientationPredictor]
- page_orientation_predictor: Optional[OrientationPredictor]
+ crop_orientation_predictor: OrientationPredictor | None
+ page_orientation_predictor: OrientationPredictor | None

  def __init__(
  self,
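
Note: the annotation change above recurs throughout the hunks in this diff. Types written with typing.Optional, Union, List, Dict, and Tuple are rewritten with built-in generics (PEP 585) and the X | None union syntax (PEP 604). A minimal, standalone sketch of the before/after style (illustrative only, not code taken from the package):

# Sketch of the annotation style adopted in 0.11.0; the function name and
# body are hypothetical, only the typing idiom mirrors the diff above.
import numpy as np

def keep_valid_pages(pages: list[np.ndarray | None]) -> dict[str, list[np.ndarray]]:
    # 0.10.0 style: def keep_valid_pages(pages: List[Union[np.ndarray, None]]) -> Dict[str, List[np.ndarray]]:
    return {"pages": [page for page in pages if page is not None]}
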
doctr/models/kie_predictor/pytorch.py

@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Union
+ from typing import Any

  import numpy as np
  import torch
@@ -24,7 +24,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  """Implements an object able to localize and identify text elements in a set of documents

  Args:
- ----
  det_predictor: detection module
  reco_predictor: recognition module
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -52,8 +51,8 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  **kwargs: Any,
  ) -> None:
  nn.Module.__init__(self)
- self.det_predictor = det_predictor.eval()  # type: ignore[attr-defined]
- self.reco_predictor = reco_predictor.eval()  # type: ignore[attr-defined]
+ self.det_predictor = det_predictor.eval()
+ self.reco_predictor = reco_predictor.eval()
  _KIEPredictor.__init__(
  self,
  assume_straight_pages,
@@ -69,7 +68,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  @torch.inference_mode()
  def forward(
  self,
- pages: List[Union[np.ndarray, torch.Tensor]],
+ pages: list[np.ndarray | torch.Tensor],
  **kwargs: Any,
  ) -> Document:
  # Dimension check
@@ -89,7 +88,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  for out_map in out_maps
  ]
  if self.detect_orientation:
- general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)  # type: ignore[arg-type]
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
  orientations = [
  {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
  ]
@@ -98,14 +97,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  general_pages_orientations = None
  origin_pages_orientations = None
  if self.straighten_pages:
- pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)  # type: ignore
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
  # update page shapes after straightening
  origin_page_shapes = [page.shape[:2] for page in pages]

  # Forward again to get predictions on straight pages
  loc_preds = self.det_predictor(pages, **kwargs)

- dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore[assignment]
+ dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore[assignment]

  # Detach objectness scores from loc_preds
  objectness_scores = {}
@@ -125,7 +124,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  crops = {}
  for class_name in dict_loc_preds.keys():
  crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
- pages,  # type: ignore[arg-type]
+ pages,
  dict_loc_preds[class_name],
  channels_last=channels_last,
  assume_straight_pages=self.assume_straight_pages,
@@ -150,18 +149,18 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  if not crop_orientations:
  crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}

- boxes: Dict = {}
- text_preds: Dict = {}
- word_crop_orientations: Dict = {}
+ boxes: dict = {}
+ text_preds: dict = {}
+ word_crop_orientations: dict = {}
  for class_name in dict_loc_preds.keys():
  boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
  dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
  )

- boxes_per_page: List[Dict] = invert_data_structure(boxes)  # type: ignore[assignment]
- objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores)  # type: ignore[assignment]
- text_preds_per_page: List[Dict] = invert_data_structure(text_preds)  # type: ignore[assignment]
- crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations)  # type: ignore[assignment]
+ boxes_per_page: list[dict] = invert_data_structure(boxes)  # type: ignore[assignment]
+ objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores)  # type: ignore[assignment]
+ text_preds_per_page: list[dict] = invert_data_structure(text_preds)  # type: ignore[assignment]
+ crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations)  # type: ignore[assignment]

  if self.detect_language:
  languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -170,7 +169,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  languages_dict = None

  out = self.doc_builder(
- pages,  # type: ignore[arg-type]
+ pages,
  boxes_per_page,
  objectness_scores_per_page,
  text_preds_per_page,
@@ -182,7 +181,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  return out

  @staticmethod
- def get_text(text_pred: Dict) -> str:
+ def get_text(text_pred: dict) -> str:
  text = []
  for value in text_pred.values():
  text += [item[0] for item in value]
doctr/models/kie_predictor/tensorflow.py

@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Union
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -24,7 +24,6 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  """Implements an object able to localize and identify text elements in a set of documents

  Args:
- ----
  det_predictor: detection module
  reco_predictor: recognition module
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -69,7 +68,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):

  def __call__(
  self,
- pages: List[Union[np.ndarray, tf.Tensor]],
+ pages: list[np.ndarray | tf.Tensor],
  **kwargs: Any,
  ) -> Document:
  # Dimension check
@@ -103,9 +102,9 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  origin_page_shapes = [page.shape[:2] for page in pages]

  # Forward again to get predictions on straight pages
- loc_preds = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
+ loc_preds = self.det_predictor(pages, **kwargs)

- dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore
+ dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore

  # Detach objectness scores from loc_preds
  objectness_scores = {}
@@ -148,18 +147,18 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  if not crop_orientations:
  crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}

- boxes: Dict = {}
- text_preds: Dict = {}
- word_crop_orientations: Dict = {}
+ boxes: dict = {}
+ text_preds: dict = {}
+ word_crop_orientations: dict = {}
  for class_name in dict_loc_preds.keys():
  boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
  dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
  )

- boxes_per_page: List[Dict] = invert_data_structure(boxes)  # type: ignore[assignment]
- objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores)  # type: ignore[assignment]
- text_preds_per_page: List[Dict] = invert_data_structure(text_preds)  # type: ignore[assignment]
- crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations)  # type: ignore[assignment]
+ boxes_per_page: list[dict] = invert_data_structure(boxes)  # type: ignore[assignment]
+ objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores)  # type: ignore[assignment]
+ text_preds_per_page: list[dict] = invert_data_structure(text_preds)  # type: ignore[assignment]
+ crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations)  # type: ignore[assignment]

  if self.detect_language:
  languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -180,7 +179,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  return out

  @staticmethod
- def get_text(text_pred: Dict) -> str:
+ def get_text(text_pred: dict) -> str:
  text = []
  for value in text_pred.values():
  text += [item[0] for item in value]
doctr/models/modules/layers/__init__.py

@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
  from .tensorflow import *
- elif is_torch_available():
- from .pytorch import *  # type: ignore[assignment]
doctr/models/modules/layers/pytorch.py

@@ -1,9 +1,8 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Tuple, Union

  import numpy as np
  import torch
@@ -19,7 +18,7 @@ class FASTConvLayer(nn.Module):
  self,
  in_channels: int,
  out_channels: int,
- kernel_size: Union[int, Tuple[int, int]],
+ kernel_size: int | tuple[int, int],
  stride: int = 1,
  dilation: int = 1,
  groups: int = 1,
@@ -93,9 +92,7 @@ class FASTConvLayer(nn.Module):

  # The following logic is used to reparametrize the layer
  # Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
- def _identity_to_conv(
- self, identity: Union[nn.BatchNorm2d, None]
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
+ def _identity_to_conv(self, identity: nn.BatchNorm2d | None) -> tuple[torch.Tensor, torch.Tensor] | tuple[int, int]:
  if identity is None or identity.running_var is None:
  return 0, 0
  if not hasattr(self, "id_tensor"):
@@ -106,18 +103,18 @@ class FASTConvLayer(nn.Module):
  id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
  self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
  kernel = self.id_tensor
- std = (identity.running_var + identity.eps).sqrt()
+ std = (identity.running_var + identity.eps).sqrt()  # type: ignore
  t = (identity.weight / std).reshape(-1, 1, 1, 1)
  return kernel * t, identity.bias - identity.running_mean * identity.weight / std

- def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
  kernel = conv.weight
  kernel = self._pad_to_mxn_tensor(kernel)
  std = (bn.running_var + bn.eps).sqrt()  # type: ignore
  t = (bn.weight / std).reshape(-1, 1, 1, 1)
  return kernel * t, bn.bias - bn.running_mean * bn.weight / std

- def _get_equivalent_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
  kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
  if self.ver_conv is not None:
  kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)  # type: ignore[arg-type]
doctr/models/modules/layers/tensorflow.py

@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Tuple, Union
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -21,7 +21,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
  self,
  in_channels: int,
  out_channels: int,
- kernel_size: Union[int, Tuple[int, int]],
+ kernel_size: int | tuple[int, int],
  stride: int = 1,
  dilation: int = 1,
  groups: int = 1,
@@ -103,9 +103,7 @@ class FASTConvLayer(layers.Layer, NestedObject):

  # The following logic is used to reparametrize the layer
  # Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
- def _identity_to_conv(
- self, identity: layers.BatchNormalization
- ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
+ def _identity_to_conv(self, identity: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor] | tuple[int, int]:
  if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
  return 0, 0
  if not hasattr(self, "id_tensor"):
@@ -120,7 +118,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
  t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
  return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std

- def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
+ def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor]:
  kernel = conv.kernel
  kernel = self._pad_to_mxn_tensor(kernel)
  std = tf.sqrt(bn.moving_variance + bn.epsilon)
doctr/models/modules/transformer/__init__.py

@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
  from .tensorflow import *
- elif is_torch_available():
- from .pytorch import *  # type: ignore[assignment]
doctr/models/modules/transformer/pytorch.py

@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,8 @@
  # This module 'transformer.py' is inspired by https://github.com/wenwenyu/MASTER-pytorch and Decoder is borrowed

  import math
- from typing import Any, Callable, Optional, Tuple
+ from collections.abc import Callable
+ from typing import Any

  import torch
  from torch import nn
@@ -33,26 +34,24 @@ class PositionalEncoding(nn.Module):
  """Forward pass

  Args:
- ----
  x: embeddings (batch, max_len, d_model)

- Returns
- -------
+ Returns:
  positional embeddings (batch, max_len, d_model)
  """
- x = x + self.pe[:, : x.size(1)]
+ x = x + self.pe[:, : x.size(1)]  # type: ignore[index]
  return self.dropout(x)


  def scaled_dot_product_attention(
- query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
  """Scaled Dot-Product Attention"""
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
  if mask is not None:
  # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
- scores = scores.masked_fill(mask == 0, float("-inf"))
- p_attn = torch.softmax(scores, dim=-1)
+ scores = scores.masked_fill(mask == 0, float("-inf"))  # type: ignore[attr-defined]
+ p_attn = torch.softmax(scores, dim=-1)  # type: ignore[call-overload]
  return torch.matmul(p_attn, value), p_attn


@@ -130,7 +129,7 @@ class EncoderBlock(nn.Module):
  PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
  ])

- def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
  output = x

  for i in range(self.num_layers):
@@ -183,8 +182,8 @@ class Decoder(nn.Module):
  self,
  tgt: torch.Tensor,
  memory: torch.Tensor,
- source_mask: Optional[torch.Tensor] = None,
- target_mask: Optional[torch.Tensor] = None,
+ source_mask: torch.Tensor | None = None,
+ target_mask: torch.Tensor | None = None,
  ) -> torch.Tensor:
  tgt = self.embed(tgt) * math.sqrt(self.d_model)
  pos_enc_tgt = self.positional_encoding(tgt)
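
Note: the hunks above leave the behavior of scaled_dot_product_attention unchanged; only the annotations and type-ignore pragmas move. A short usage sketch, assuming the function can be imported from doctr.models.modules.transformer.pytorch as the file listing suggests (tensor shapes are illustrative):

# Exercises the attention helper shown in the hunk above; mask uses integers,
# matching the ONNX-compatibility note in the code.
import torch

from doctr.models.modules.transformer.pytorch import scaled_dot_product_attention

batch, heads, seq_len, d_k = 2, 8, 16, 32
query = torch.rand(batch, heads, seq_len, d_k)
key = torch.rand(batch, heads, seq_len, d_k)
value = torch.rand(batch, heads, seq_len, d_k)
mask = torch.ones(batch, 1, 1, seq_len, dtype=torch.uint8)  # 0 = masked position

context, attention = scaled_dot_product_attention(query, key, value, mask=mask)
print(context.shape)    # torch.Size([2, 8, 16, 32])
print(attention.shape)  # torch.Size([2, 8, 16, 16])
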
doctr/models/modules/transformer/tensorflow.py

@@ -1,10 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import math
- from typing import Any, Callable, Optional, Tuple
+ from collections.abc import Callable
+ from typing import Any

  import tensorflow as tf
  from tensorflow.keras import layers
@@ -43,12 +44,10 @@ class PositionalEncoding(layers.Layer, NestedObject):
  """Forward pass

  Args:
- ----
  x: embeddings (batch, max_len, d_model)
  **kwargs: additional arguments

- Returns
- -------
+ Returns:
  positional embeddings (batch, max_len, d_model)
  """
  if x.dtype == tf.float16:  # amp fix: cast to half
@@ -60,8 +59,8 @@ class PositionalEncoding(layers.Layer, NestedObject):

  @tf.function
  def scaled_dot_product_attention(
- query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: Optional[tf.Tensor] = None
- ) -> Tuple[tf.Tensor, tf.Tensor]:
+ query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: tf.Tensor | None = None
+ ) -> tuple[tf.Tensor, tf.Tensor]:
  """Scaled Dot-Product Attention"""
  scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) / math.sqrt(query.shape[-1])
  if mask is not None:
@@ -160,7 +159,7 @@ class EncoderBlock(layers.Layer, NestedObject):
  PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
  ]

- def call(self, x: tf.Tensor, mask: Optional[tf.Tensor] = None, **kwargs: Any) -> tf.Tensor:
+ def call(self, x: tf.Tensor, mask: tf.Tensor | None = None, **kwargs: Any) -> tf.Tensor:
  output = x

  for i in range(self.num_layers):
@@ -210,8 +209,8 @@ class Decoder(layers.Layer, NestedObject):
  self,
  tgt: tf.Tensor,
  memory: tf.Tensor,
- source_mask: Optional[tf.Tensor] = None,
- target_mask: Optional[tf.Tensor] = None,
+ source_mask: tf.Tensor | None = None,
+ target_mask: tf.Tensor | None = None,
  **kwargs: Any,
  ) -> tf.Tensor:
  tgt = self.embed(tgt, **kwargs) * math.sqrt(self.d_model)
doctr/models/modules/vision_transformer/__init__.py

@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
  from .tensorflow import *
- elif is_torch_available():
- from .pytorch import *  # type: ignore[assignment]
doctr/models/modules/vision_transformer/pytorch.py

@@ -1,10 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import math
- from typing import Tuple

  import torch
  from torch import nn
@@ -15,7 +14,7 @@ __all__ = ["PatchEmbedding"]
  class PatchEmbedding(nn.Module):
  """Compute 2D patch embeddings with cls token and positional encoding"""

- def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None:
+ def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
  super().__init__()
  channels, height, width = input_shape
  self.patch_size = patch_size
doctr/models/modules/vision_transformer/tensorflow.py

@@ -1,10 +1,10 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import math
- from typing import Any, Tuple
+ from typing import Any

  import tensorflow as tf
  from tensorflow.keras import layers
@@ -17,7 +17,7 @@ __all__ = ["PatchEmbedding"]
  class PatchEmbedding(layers.Layer, NestedObject):
  """Compute 2D patch embeddings with cls token and positional encoding"""

- def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None:
+ def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
  super().__init__()
  height, width, _ = input_shape
  self.patch_size = patch_size
doctr/models/predictor/__init__.py

@@ -1,6 +1,6 @@
- from doctr.file_utils import is_tf_available
+ from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
- from .tensorflow import *
- else:
- from .pytorch import *  # type: ignore[assignment]
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
+ from .tensorflow import *  # type: ignore[assignment]
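
Note: the backend-selection __init__ hunks in this diff all make the same change: is_torch_available() is now tested before is_tf_available(), so when both backends are installed the PyTorch implementations take precedence (previously the TensorFlow branch won). A quick way to see which backend will be picked up after upgrading, using only helpers that appear in the diff itself:

# Mirrors the precedence now used by the __init__ modules shown above:
# PyTorch first, TensorFlow only as a fallback.
from doctr.file_utils import is_tf_available, is_torch_available

if is_torch_available():
    print("doctr will load its PyTorch implementations")
elif is_tf_available():
    print("doctr will load its TensorFlow implementations")
else:
    print("no supported deep learning backend was found")
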