python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Union
+ from typing import Any

  import numpy as np
  import torch
@@ -24,7 +24,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  """Implements an object able to localize and identify text elements in a set of documents

  Args:
- ----
  det_predictor: detection module
  reco_predictor: recognition module
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -52,8 +51,8 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  **kwargs: Any,
  ) -> None:
  nn.Module.__init__(self)
- self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
- self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
+ self.det_predictor = det_predictor.eval()
+ self.reco_predictor = reco_predictor.eval()
  _KIEPredictor.__init__(
  self,
  assume_straight_pages,
@@ -69,7 +68,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  @torch.inference_mode()
  def forward(
  self,
- pages: List[Union[np.ndarray, torch.Tensor]],
+ pages: list[np.ndarray | torch.Tensor],
  **kwargs: Any,
  ) -> Document:
  # Dimension check
@@ -89,7 +88,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  for out_map in out_maps
  ]
  if self.detect_orientation:
- general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) # type: ignore[arg-type]
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
  orientations = [
  {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
  ]
@@ -98,14 +97,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  general_pages_orientations = None
  origin_pages_orientations = None
  if self.straighten_pages:
- pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
  # update page shapes after straightening
  origin_page_shapes = [page.shape[:2] for page in pages]

  # Forward again to get predictions on straight pages
  loc_preds = self.det_predictor(pages, **kwargs)

- dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]
+ dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]

  # Detach objectness scores from loc_preds
  objectness_scores = {}
@@ -125,7 +124,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  crops = {}
  for class_name in dict_loc_preds.keys():
  crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
- pages, # type: ignore[arg-type]
+ pages,
  dict_loc_preds[class_name],
  channels_last=channels_last,
  assume_straight_pages=self.assume_straight_pages,
@@ -150,18 +149,18 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  if not crop_orientations:
  crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}

- boxes: Dict = {}
- text_preds: Dict = {}
- word_crop_orientations: Dict = {}
+ boxes: dict = {}
+ text_preds: dict = {}
+ word_crop_orientations: dict = {}
  for class_name in dict_loc_preds.keys():
  boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
  dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
  )

- boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
- objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
- text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
- crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
+ boxes_per_page: list[dict] = invert_data_structure(boxes) # type: ignore[assignment]
+ objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
+ text_preds_per_page: list[dict] = invert_data_structure(text_preds) # type: ignore[assignment]
+ crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]

  if self.detect_language:
  languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -170,11 +169,11 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  languages_dict = None

  out = self.doc_builder(
- pages, # type: ignore[arg-type]
+ pages,
  boxes_per_page,
  objectness_scores_per_page,
  text_preds_per_page,
- origin_page_shapes, # type: ignore[arg-type]
+ origin_page_shapes,
  crop_orientations_per_page,
  orientations,
  languages_dict,
@@ -182,7 +181,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
  return out

  @staticmethod
- def get_text(text_pred: Dict) -> str:
+ def get_text(text_pred: dict) -> str:
  text = []
  for value in text_pred.values():
  text += [item[0] for item in value]
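Most of the churn in the hunks above (and in the TensorFlow counterpart below) comes from a typing modernization: `typing.Dict`, `List`, `Tuple`, `Union` and `Optional` are replaced by built-in generics and `|` unions (PEP 585/604), and several `# type: ignore` pragmas become unnecessary. An illustrative before/after sketch, not taken from the diff (the function names are made up):

    # Illustrative only: the annotation style driving most of these line changes.
    import numpy as np
    import torch
    from typing import Dict, List, Optional, Union

    # 0.10.0 style
    def forward_old(pages: List[Union[np.ndarray, torch.Tensor]], masks: Optional[Dict] = None) -> None: ...

    # 0.12.0 style: built-in generics (PEP 585) and X | None unions (PEP 604)
    def forward_new(pages: list[np.ndarray | torch.Tensor], masks: dict | None = None) -> None: ...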
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Union
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -24,7 +24,6 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  """Implements an object able to localize and identify text elements in a set of documents

  Args:
- ----
  det_predictor: detection module
  reco_predictor: recognition module
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -69,7 +68,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):

  def __call__(
  self,
- pages: List[Union[np.ndarray, tf.Tensor]],
+ pages: list[np.ndarray | tf.Tensor],
  **kwargs: Any,
  ) -> Document:
  # Dimension check
@@ -103,9 +102,9 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  origin_page_shapes = [page.shape[:2] for page in pages]

  # Forward again to get predictions on straight pages
- loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
+ loc_preds = self.det_predictor(pages, **kwargs)

- dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
+ dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore

  # Detach objectness scores from loc_preds
  objectness_scores = {}
@@ -148,18 +147,18 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  if not crop_orientations:
  crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}

- boxes: Dict = {}
- text_preds: Dict = {}
- word_crop_orientations: Dict = {}
+ boxes: dict = {}
+ text_preds: dict = {}
+ word_crop_orientations: dict = {}
  for class_name in dict_loc_preds.keys():
  boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
  dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
  )

- boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
- objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
- text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
- crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
+ boxes_per_page: list[dict] = invert_data_structure(boxes) # type: ignore[assignment]
+ objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
+ text_preds_per_page: list[dict] = invert_data_structure(text_preds) # type: ignore[assignment]
+ crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]

  if self.detect_language:
  languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -172,7 +171,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  boxes_per_page,
  objectness_scores_per_page,
  text_preds_per_page,
- origin_page_shapes, # type: ignore[arg-type]
+ origin_page_shapes,
  crop_orientations_per_page,
  orientations,
  languages_dict,
@@ -180,7 +179,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
  return out

  @staticmethod
- def get_text(text_pred: Dict) -> str:
+ def get_text(text_pred: dict) -> str:
  text = []
  for value in text_pred.values():
  text += [item[0] for item in value]
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
  from .tensorflow import *
- elif is_torch_available():
- from .pytorch import * # type: ignore[assignment]
@@ -1,15 +1,62 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Tuple, Union

  import numpy as np
  import torch
  import torch.nn as nn

- __all__ = ["FASTConvLayer"]
+ __all__ = ["FASTConvLayer", "DropPath", "AdaptiveAvgPool2d"]
+
+
+ class DropPath(nn.Module):
+ """
+ DropPath (Drop Connect) layer. This is a stochastic version of the identity layer.
+ """
+
+ # Borrowed from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.drop_prob == 0.0 or not self.training:
+ return x
+ keep_prob = 1 - self.drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with different dimensions
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and self.scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+ class AdaptiveAvgPool2d(nn.Module):
+ """
+ Custom AdaptiveAvgPool2d implementation which is ONNX and `torch.compile` compatible.
+
+ """
+
+ def __init__(self, output_size):
+ super().__init__()
+ self.output_size = output_size
+
+ def forward(self, x: torch.Tensor):
+ H_out, W_out = self.output_size
+ N, C, H, W = x.shape
+
+ out = torch.empty((N, C, H_out, W_out), device=x.device, dtype=x.dtype)
+ for oh in range(H_out):
+ start_h = (oh * H) // H_out
+ end_h = ((oh + 1) * H + H_out - 1) // H_out # ceil((oh+1)*H / H_out)
+ for ow in range(W_out):
+ start_w = (ow * W) // W_out
+ end_w = ((ow + 1) * W + W_out - 1) // W_out # ceil((ow+1)*W / W_out)
+ # average over the window
+ out[:, :, oh, ow] = x[:, :, start_h:end_h, start_w:end_w].mean(dim=(-2, -1))
+ return out


  class FASTConvLayer(nn.Module):
@@ -19,7 +66,7 @@ class FASTConvLayer(nn.Module):
  self,
  in_channels: int,
  out_channels: int,
- kernel_size: Union[int, Tuple[int, int]],
+ kernel_size: int | tuple[int, int],
  stride: int = 1,
  dilation: int = 1,
  groups: int = 1,
@@ -93,9 +140,7 @@ class FASTConvLayer(nn.Module):

  # The following logic is used to reparametrize the layer
  # Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
- def _identity_to_conv(
- self, identity: Union[nn.BatchNorm2d, None]
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
+ def _identity_to_conv(self, identity: nn.BatchNorm2d | None) -> tuple[torch.Tensor, torch.Tensor] | tuple[int, int]:
  if identity is None or identity.running_var is None:
  return 0, 0
  if not hasattr(self, "id_tensor"):
@@ -106,18 +151,18 @@ class FASTConvLayer(nn.Module):
  id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
  self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
  kernel = self.id_tensor
- std = (identity.running_var + identity.eps).sqrt()
+ std = (identity.running_var + identity.eps).sqrt() # type: ignore
  t = (identity.weight / std).reshape(-1, 1, 1, 1)
  return kernel * t, identity.bias - identity.running_mean * identity.weight / std

- def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
  kernel = conv.weight
  kernel = self._pad_to_mxn_tensor(kernel)
  std = (bn.running_var + bn.eps).sqrt() # type: ignore
  t = (bn.weight / std).reshape(-1, 1, 1, 1)
  return kernel * t, bn.bias - bn.running_mean * bn.weight / std

- def _get_equivalent_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
  kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
  if self.ver_conv is not None:
  kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) # type: ignore[arg-type]
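The hunk above adds two small utility layers alongside FASTConvLayer. A minimal usage sketch, not part of the diff, assuming both classes are re-exported from doctr.models.modules.layers as the updated __all__ and the backend __init__ change above suggest:

    import torch
    import torch.nn.functional as F
    from doctr.models.modules.layers import AdaptiveAvgPool2d, DropPath

    x = torch.rand(2, 64, 7, 5)

    # DropPath is the identity in eval mode; during training it zeroes whole
    # samples with probability drop_prob and rescales survivors by 1/(1 - drop_prob).
    drop = DropPath(drop_prob=0.1)
    drop.eval()
    assert torch.equal(drop(x), x)

    # The custom AdaptiveAvgPool2d averages over index windows
    # [floor(o*H/H_out), ceil((o+1)*H/H_out)); e.g. H=7, H_out=3 gives row
    # windows [0,3), [2,5), [4,7). Output shape is (N, C, H_out, W_out).
    pool = AdaptiveAvgPool2d((3, 3))
    out = pool(x)
    assert out.shape == (2, 64, 3, 3)
    # Expected to agree with the built-in op for these sizes:
    print(torch.allclose(out, F.adaptive_avg_pool2d(x, (3, 3))))

Per the class docstring, the loop-based pooling is there to keep the operation ONNX- and torch.compile-friendly.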
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Tuple, Union
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -21,7 +21,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
  self,
  in_channels: int,
  out_channels: int,
- kernel_size: Union[int, Tuple[int, int]],
+ kernel_size: int | tuple[int, int],
  stride: int = 1,
  dilation: int = 1,
  groups: int = 1,
@@ -103,9 +103,7 @@ class FASTConvLayer(layers.Layer, NestedObject):

  # The following logic is used to reparametrize the layer
  # Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
- def _identity_to_conv(
- self, identity: layers.BatchNormalization
- ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
+ def _identity_to_conv(self, identity: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor] | tuple[int, int]:
  if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
  return 0, 0
  if not hasattr(self, "id_tensor"):
@@ -120,7 +118,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
  t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
  return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std

- def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
+ def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor]:
  kernel = conv.kernel
  kernel = self._pad_to_mxn_tensor(kernel)
  std = tf.sqrt(bn.moving_variance + bn.epsilon)
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
  from .tensorflow import *
- elif is_torch_available():
- from .pytorch import * # type: ignore[assignment]
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,8 @@
  # This module 'transformer.py' is inspired by https://github.com/wenwenyu/MASTER-pytorch and Decoder is borrowed

  import math
- from typing import Any, Callable, Optional, Tuple
+ from collections.abc import Callable
+ from typing import Any

  import torch
  from torch import nn
@@ -33,26 +34,24 @@ class PositionalEncoding(nn.Module):
  """Forward pass

  Args:
- ----
  x: embeddings (batch, max_len, d_model)

- Returns
- -------
+ Returns:
  positional embeddings (batch, max_len, d_model)
  """
- x = x + self.pe[:, : x.size(1)]
+ x = x + self.pe[:, : x.size(1)] # type: ignore[index]
  return self.dropout(x)


  def scaled_dot_product_attention(
- query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
  """Scaled Dot-Product Attention"""
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
  if mask is not None:
  # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
- scores = scores.masked_fill(mask == 0, float("-inf"))
- p_attn = torch.softmax(scores, dim=-1)
+ scores = scores.masked_fill(mask == 0, float("-inf")) # type: ignore[attr-defined]
+ p_attn = torch.softmax(scores, dim=-1) # type: ignore[call-overload]
  return torch.matmul(p_attn, value), p_attn


@@ -130,7 +129,7 @@ class EncoderBlock(nn.Module):
  PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
  ])

- def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
  output = x

  for i in range(self.num_layers):
@@ -183,8 +182,8 @@ class Decoder(nn.Module):
  self,
  tgt: torch.Tensor,
  memory: torch.Tensor,
- source_mask: Optional[torch.Tensor] = None,
- target_mask: Optional[torch.Tensor] = None,
+ source_mask: torch.Tensor | None = None,
+ target_mask: torch.Tensor | None = None,
  ) -> torch.Tensor:
  tgt = self.embed(tgt) * math.sqrt(self.d_model)
  pos_enc_tgt = self.positional_encoding(tgt)
@@ -1,10 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import math
- from typing import Any, Callable, Optional, Tuple
+ from collections.abc import Callable
+ from typing import Any

  import tensorflow as tf
  from tensorflow.keras import layers
@@ -43,12 +44,10 @@ class PositionalEncoding(layers.Layer, NestedObject):
  """Forward pass

  Args:
- ----
  x: embeddings (batch, max_len, d_model)
  **kwargs: additional arguments

- Returns
- -------
+ Returns:
  positional embeddings (batch, max_len, d_model)
  """
  if x.dtype == tf.float16: # amp fix: cast to half
@@ -60,8 +59,8 @@ class PositionalEncoding(layers.Layer, NestedObject):

  @tf.function
  def scaled_dot_product_attention(
- query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: Optional[tf.Tensor] = None
- ) -> Tuple[tf.Tensor, tf.Tensor]:
+ query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: tf.Tensor | None = None
+ ) -> tuple[tf.Tensor, tf.Tensor]:
  """Scaled Dot-Product Attention"""
  scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) / math.sqrt(query.shape[-1])
  if mask is not None:
@@ -160,7 +159,7 @@ class EncoderBlock(layers.Layer, NestedObject):
  PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
  ]

- def call(self, x: tf.Tensor, mask: Optional[tf.Tensor] = None, **kwargs: Any) -> tf.Tensor:
+ def call(self, x: tf.Tensor, mask: tf.Tensor | None = None, **kwargs: Any) -> tf.Tensor:
  output = x

  for i in range(self.num_layers):
@@ -210,8 +209,8 @@ class Decoder(layers.Layer, NestedObject):
  self,
  tgt: tf.Tensor,
  memory: tf.Tensor,
- source_mask: Optional[tf.Tensor] = None,
- target_mask: Optional[tf.Tensor] = None,
+ source_mask: tf.Tensor | None = None,
+ target_mask: tf.Tensor | None = None,
  **kwargs: Any,
  ) -> tf.Tensor:
  tgt = self.embed(tgt, **kwargs) * math.sqrt(self.d_model)
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
  from .tensorflow import *
- elif is_torch_available():
- from .pytorch import * # type: ignore[assignment]
@@ -1,10 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import math
- from typing import Tuple

  import torch
  from torch import nn
@@ -15,7 +14,7 @@ __all__ = ["PatchEmbedding"]
  class PatchEmbedding(nn.Module):
  """Compute 2D patch embeddings with cls token and positional encoding"""

- def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None:
+ def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
  super().__init__()
  channels, height, width = input_shape
  self.patch_size = patch_size
@@ -1,10 +1,10 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import math
- from typing import Any, Tuple
+ from typing import Any

  import tensorflow as tf
  from tensorflow.keras import layers
@@ -17,7 +17,7 @@ __all__ = ["PatchEmbedding"]
  class PatchEmbedding(layers.Layer, NestedObject):
  """Compute 2D patch embeddings with cls token and positional encoding"""

- def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None:
+ def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
  super().__init__()
  height, width, _ = input_shape
  self.patch_size = patch_size
@@ -1,6 +1,6 @@
- from doctr.file_utils import is_tf_available
+ from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
- from .tensorflow import *
- else:
- from .pytorch import * # type: ignore[assignment]
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
+ from .tensorflow import * # type: ignore[assignment]
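Every backend-dispatching __init__ in this diff flips its priority the same way: 0.10.0 imported the TensorFlow implementation when both frameworks were present, while 0.12.0 prefers PyTorch and falls back to TensorFlow. A small sketch of the resulting resolution order (an illustration, assuming both backends are installed):

    from doctr.file_utils import is_tf_available, is_torch_available

    # Mirrors the reordered __init__ modules above: PyTorch wins when both
    # frameworks are importable; TensorFlow is only used as a fallback.
    if is_torch_available():
        backend = "pytorch"      # 0.12.0 checks this branch first
    elif is_tf_available():
        backend = "tensorflow"   # 0.10.0 used to check TensorFlow first
    else:
        backend = None
    print(backend)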