python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162):
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +8 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +7 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +4 -5
  17. doctr/datasets/ic13.py +4 -5
  18. doctr/datasets/iiit5k.py +6 -5
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +6 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +6 -5
  27. doctr/datasets/svhn.py +6 -5
  28. doctr/datasets/svt.py +4 -5
  29. doctr/datasets/synthtext.py +4 -5
  30. doctr/datasets/utils.py +34 -29
  31. doctr/datasets/vocabs.py +17 -7
  32. doctr/datasets/wildreceipt.py +14 -10
  33. doctr/file_utils.py +2 -7
  34. doctr/io/elements.py +59 -79
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +30 -48
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +8 -11
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +5 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +8 -21
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +6 -8
  52. doctr/models/classification/predictor/tensorflow.py +6 -8
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +20 -31
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +8 -15
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +9 -12
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +6 -12
  65. doctr/models/classification/zoo.py +19 -14
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +14 -26
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +14 -23
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +5 -6
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +3 -7
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +4 -5
  91. doctr/models/kie_predictor/pytorch.py +18 -19
  92. doctr/models/kie_predictor/tensorflow.py +13 -14
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -10
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  101. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +28 -29
  104. doctr/models/predictor/pytorch.py +12 -13
  105. doctr/models/predictor/tensorflow.py +8 -9
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +10 -14
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +11 -23
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +12 -22
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +16 -22
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +12 -21
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +12 -20
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +14 -17
  136. doctr/models/utils/tensorflow.py +17 -16
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +20 -28
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +58 -22
  145. doctr/transforms/modules/tensorflow.py +18 -32
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +16 -47
  150. doctr/utils/metrics.py +17 -37
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +9 -13
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.10.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,11 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
7
 
8
- from typing import Dict, List, Tuple, Union
9
8
 
10
9
  import cv2
11
10
  import numpy as np
@@ -23,7 +22,6 @@ class LinkNetPostProcessor(DetectionPostProcessor):
23
22
  """Implements a post processor for LinkNet model.
24
23
 
25
24
  Args:
26
- ----
27
25
  bin_thresh: threshold used to binzarized p_map at inference time
28
26
  box_thresh: minimal objectness score to consider a box
29
27
  assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -45,11 +43,9 @@ class LinkNetPostProcessor(DetectionPostProcessor):
45
43
  """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
46
44
 
47
45
  Args:
48
- ----
49
46
  points: The first parameter.
50
47
 
51
48
  Returns:
52
- -------
53
49
  a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
54
50
  """
55
51
  if not self.assume_straight_pages:
@@ -94,24 +90,22 @@ class LinkNetPostProcessor(DetectionPostProcessor):
94
90
  """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
95
91
 
96
92
  Args:
97
- ----
98
93
  pred: Pred map from differentiable linknet output
99
94
  bitmap: Bitmap map computed from pred (binarized)
100
95
  angle_tol: Comparison tolerance of the angle with the median angle across the page
101
96
  ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
102
97
 
103
98
  Returns:
104
- -------
105
99
  np tensor boxes for the bitmap, each box is a 6-element list
106
100
  containing x, y, w, h, alpha, score for the box
107
101
  """
108
102
  height, width = bitmap.shape[:2]
109
- boxes: List[Union[np.ndarray, List[float]]] = []
103
+ boxes: list[np.ndarray | list[float]] = []
110
104
  # get contours from connected components on the bitmap
111
105
  contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
112
106
  for contour in contours:
113
107
  # Check whether smallest enclosing bounding box is not too small
114
- if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): # type: ignore[index]
108
+ if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
115
109
  continue
116
110
  # Compute objectness
117
111
  if self.assume_straight_pages:
@@ -152,7 +146,6 @@ class _LinkNet(BaseModel):
152
146
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
153
147
 
154
148
  Args:
155
- ----
156
149
  out_chan: number of channels for the output
157
150
  """
158
151
 
@@ -162,20 +155,18 @@ class _LinkNet(BaseModel):
162
155
 
163
156
  def build_target(
164
157
  self,
165
- target: List[Dict[str, np.ndarray]],
166
- output_shape: Tuple[int, int, int],
158
+ target: list[dict[str, np.ndarray]],
159
+ output_shape: tuple[int, int, int],
167
160
  channels_last: bool = True,
168
- ) -> Tuple[np.ndarray, np.ndarray]:
161
+ ) -> tuple[np.ndarray, np.ndarray]:
169
162
  """Build the target, and it's mask to be used from loss computation.
170
163
 
171
164
  Args:
172
- ----
173
165
  target: target coming from dataset
174
166
  output_shape: shape of the output of the model without batch_size
175
167
  channels_last: whether channels are last or not
176
168
 
177
169
  Returns:
178
- -------
179
170
  the new formatted target and the mask
180
171
  """
181
172
  if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
@@ -1,9 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Callable, Dict, List, Optional, Tuple
6
+ from collections.abc import Callable
7
+ from typing import Any
7
8
 
8
9
  import numpy as np
9
10
  import torch
@@ -20,7 +21,7 @@ from .base import LinkNetPostProcessor, _LinkNet
20
21
  __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
21
22
 
22
23
 
23
- default_cfgs: Dict[str, Dict[str, Any]] = {
24
+ default_cfgs: dict[str, dict[str, Any]] = {
24
25
  "linknet_resnet18": {
25
26
  "input_shape": (3, 1024, 1024),
26
27
  "mean": (0.798, 0.785, 0.772),
@@ -43,7 +44,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
43
44
 
44
45
 
45
46
  class LinkNetFPN(nn.Module):
46
- def __init__(self, layer_shapes: List[Tuple[int, int, int]]) -> None:
47
+ def __init__(self, layer_shapes: list[tuple[int, int, int]]) -> None:
47
48
  super().__init__()
48
49
  strides = [
49
50
  1 if (in_shape[-1] == out_shape[-1]) else 2
@@ -74,7 +75,7 @@ class LinkNetFPN(nn.Module):
74
75
  nn.ReLU(inplace=True),
75
76
  )
76
77
 
77
- def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
78
+ def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
78
79
  out = feats[-1]
79
80
  for decoder, fmap in zip(self.decoders[::-1], feats[:-1][::-1]):
80
81
  out = decoder(out) + fmap
@@ -89,7 +90,6 @@ class LinkNet(nn.Module, _LinkNet):
89
90
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
90
91
 
91
92
  Args:
92
- ----
93
93
  feature extractor: the backbone serving as feature extractor
94
94
  bin_thresh: threshold for binarization of the output feature map
95
95
  box_thresh: minimal objectness score to consider a box
@@ -108,8 +108,8 @@ class LinkNet(nn.Module, _LinkNet):
108
108
  head_chans: int = 32,
109
109
  assume_straight_pages: bool = True,
110
110
  exportable: bool = False,
111
- cfg: Optional[Dict[str, Any]] = None,
112
- class_names: List[str] = [CLASS_NAME],
111
+ cfg: dict[str, Any] | None = None,
112
+ class_names: list[str] = [CLASS_NAME],
113
113
  ) -> None:
114
114
  super().__init__()
115
115
  self.class_names = class_names
@@ -163,16 +163,16 @@ class LinkNet(nn.Module, _LinkNet):
163
163
  def forward(
164
164
  self,
165
165
  x: torch.Tensor,
166
- target: Optional[List[np.ndarray]] = None,
166
+ target: list[np.ndarray] | None = None,
167
167
  return_model_output: bool = False,
168
168
  return_preds: bool = False,
169
169
  **kwargs: Any,
170
- ) -> Dict[str, Any]:
170
+ ) -> dict[str, Any]:
171
171
  feats = self.feat_extractor(x)
172
172
  logits = self.fpn([feats[str(idx)] for idx in range(len(feats))])
173
173
  logits = self.classifier(logits)
174
174
 
175
- out: Dict[str, Any] = {}
175
+ out: dict[str, Any] = {}
176
176
  if self.exportable:
177
177
  out["logits"] = logits
178
178
  return out
@@ -183,11 +183,16 @@ class LinkNet(nn.Module, _LinkNet):
183
183
  out["out_map"] = prob_map
184
184
 
185
185
  if target is None or return_preds:
186
- # Post-process boxes
187
- out["preds"] = [
188
- dict(zip(self.class_names, preds))
189
- for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
190
- ]
186
+ # Disable for torch.compile compatibility
187
+ @torch.compiler.disable # type: ignore[attr-defined]
188
+ def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
189
+ return [
190
+ dict(zip(self.class_names, preds))
191
+ for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
192
+ ]
193
+
194
+ # Post-process boxes (keep only text predictions)
195
+ out["preds"] = _postprocess(prob_map)
191
196
 
192
197
  if target is not None:
193
198
  loss = self.compute_loss(logits, target)
@@ -198,7 +203,7 @@ class LinkNet(nn.Module, _LinkNet):
198
203
  def compute_loss(
199
204
  self,
200
205
  out_map: torch.Tensor,
201
- target: List[np.ndarray],
206
+ target: list[np.ndarray],
202
207
  gamma: float = 2.0,
203
208
  alpha: float = 0.5,
204
209
  eps: float = 1e-8,
@@ -207,7 +212,6 @@ class LinkNet(nn.Module, _LinkNet):
207
212
  <https://github.com/tensorflow/addons/>`_.
208
213
 
209
214
  Args:
210
- ----
211
215
  out_map: output feature map of the model of shape (N, num_classes, H, W)
212
216
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
213
217
  gamma: modulating factor in the focal loss formula
@@ -215,7 +219,6 @@ class LinkNet(nn.Module, _LinkNet):
215
219
  eps: epsilon factor in dice loss
216
220
 
217
221
  Returns:
218
- -------
219
222
  A loss tensor
220
223
  """
221
224
  _target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
@@ -252,9 +255,9 @@ def _linknet(
252
255
  arch: str,
253
256
  pretrained: bool,
254
257
  backbone_fn: Callable[[bool], nn.Module],
255
- fpn_layers: List[str],
258
+ fpn_layers: list[str],
256
259
  pretrained_backbone: bool = True,
257
- ignore_keys: Optional[List[str]] = None,
260
+ ignore_keys: list[str] | None = None,
258
261
  **kwargs: Any,
259
262
  ) -> LinkNet:
260
263
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -295,12 +298,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
295
298
  >>> out = model(input_tensor)
296
299
 
297
300
  Args:
298
- ----
299
301
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
300
302
  **kwargs: keyword arguments of the LinkNet architecture
301
303
 
302
304
  Returns:
303
- -------
304
305
  text detection architecture
305
306
  """
306
307
  return _linknet(
@@ -327,12 +328,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
327
328
  >>> out = model(input_tensor)
328
329
 
329
330
  Args:
330
- ----
331
331
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
332
332
  **kwargs: keyword arguments of the LinkNet architecture
333
333
 
334
334
  Returns:
335
- -------
336
335
  text detection architecture
337
336
  """
338
337
  return _linknet(
@@ -359,12 +358,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
359
358
  >>> out = model(input_tensor)
360
359
 
361
360
  Args:
362
- ----
363
361
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
364
362
  **kwargs: keyword arguments of the LinkNet architecture
365
363
 
366
364
  Returns:
367
- -------
368
365
  text detection architecture
369
366
  """
370
367
  return _linknet(
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
6
6
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
7
 
8
8
  from copy import deepcopy
9
- from typing import Any, Dict, List, Optional, Tuple
9
+ from typing import Any
10
10
 
11
11
  import numpy as np
12
12
  import tensorflow as tf
@@ -27,7 +27,7 @@ from .base import LinkNetPostProcessor, _LinkNet
27
27
 
28
28
  __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
29
29
 
30
- default_cfgs: Dict[str, Dict[str, Any]] = {
30
+ default_cfgs: dict[str, dict[str, Any]] = {
31
31
  "linknet_resnet18": {
32
32
  "mean": (0.798, 0.785, 0.772),
33
33
  "std": (0.264, 0.2749, 0.287),
@@ -73,7 +73,7 @@ class LinkNetFPN(Model, NestedObject):
73
73
  def __init__(
74
74
  self,
75
75
  out_chans: int,
76
- in_shapes: List[Tuple[int, ...]],
76
+ in_shapes: list[tuple[int, ...]],
77
77
  ) -> None:
78
78
  super().__init__()
79
79
  self.out_chans = out_chans
@@ -85,7 +85,7 @@ class LinkNetFPN(Model, NestedObject):
85
85
  for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
86
86
  ]
87
87
 
88
- def call(self, x: List[tf.Tensor], **kwargs: Any) -> tf.Tensor:
88
+ def call(self, x: list[tf.Tensor], **kwargs: Any) -> tf.Tensor:
89
89
  out = 0
90
90
  for decoder, fmap in zip(self.decoders, x[::-1]):
91
91
  out = decoder(out + fmap, **kwargs)
@@ -100,7 +100,6 @@ class LinkNet(_LinkNet, Model):
100
100
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
101
101
 
102
102
  Args:
103
- ----
104
103
  feature extractor: the backbone serving as feature extractor
105
104
  fpn_channels: number of channels each extracted feature maps is mapped to
106
105
  bin_thresh: threshold for binarization of the output feature map
@@ -111,7 +110,7 @@ class LinkNet(_LinkNet, Model):
111
110
  class_names: list of class names
112
111
  """
113
112
 
114
- _children_names: List[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]
113
+ _children_names: list[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]
115
114
 
116
115
  def __init__(
117
116
  self,
@@ -121,8 +120,8 @@ class LinkNet(_LinkNet, Model):
121
120
  box_thresh: float = 0.1,
122
121
  assume_straight_pages: bool = True,
123
122
  exportable: bool = False,
124
- cfg: Optional[Dict[str, Any]] = None,
125
- class_names: List[str] = [CLASS_NAME],
123
+ cfg: dict[str, Any] | None = None,
124
+ class_names: list[str] = [CLASS_NAME],
126
125
  ) -> None:
127
126
  super().__init__(cfg=cfg)
128
127
 
@@ -167,7 +166,7 @@ class LinkNet(_LinkNet, Model):
167
166
  def compute_loss(
168
167
  self,
169
168
  out_map: tf.Tensor,
170
- target: List[Dict[str, np.ndarray]],
169
+ target: list[dict[str, np.ndarray]],
171
170
  gamma: float = 2.0,
172
171
  alpha: float = 0.5,
173
172
  eps: float = 1e-8,
@@ -176,7 +175,6 @@ class LinkNet(_LinkNet, Model):
176
175
  <https://github.com/tensorflow/addons/>`_.
177
176
 
178
177
  Args:
179
- ----
180
178
  out_map: output feature map of the model of shape N x H x W x 1
181
179
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
182
180
  gamma: modulating factor in the focal loss formula
@@ -184,7 +182,6 @@ class LinkNet(_LinkNet, Model):
184
182
  eps: epsilon factor in dice loss
185
183
 
186
184
  Returns:
187
- -------
188
185
  A loss tensor
189
186
  """
190
187
  seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True)
@@ -218,16 +215,16 @@ class LinkNet(_LinkNet, Model):
218
215
  def call(
219
216
  self,
220
217
  x: tf.Tensor,
221
- target: Optional[List[Dict[str, np.ndarray]]] = None,
218
+ target: list[dict[str, np.ndarray]] | None = None,
222
219
  return_model_output: bool = False,
223
220
  return_preds: bool = False,
224
221
  **kwargs: Any,
225
- ) -> Dict[str, Any]:
222
+ ) -> dict[str, Any]:
226
223
  feat_maps = self.feat_extractor(x, **kwargs)
227
224
  logits = self.fpn(feat_maps, **kwargs)
228
225
  logits = self.classifier(logits, **kwargs)
229
226
 
230
- out: Dict[str, tf.Tensor] = {}
227
+ out: dict[str, tf.Tensor] = {}
231
228
  if self.exportable:
232
229
  out["logits"] = logits
233
230
  return out
@@ -253,9 +250,9 @@ def _linknet(
253
250
  arch: str,
254
251
  pretrained: bool,
255
252
  backbone_fn,
256
- fpn_layers: List[str],
253
+ fpn_layers: list[str],
257
254
  pretrained_backbone: bool = True,
258
- input_shape: Optional[Tuple[int, int, int]] = None,
255
+ input_shape: tuple[int, int, int] | None = None,
259
256
  **kwargs: Any,
260
257
  ) -> LinkNet:
261
258
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -305,12 +302,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
305
302
  >>> out = model(input_tensor)
306
303
 
307
304
  Args:
308
- ----
309
305
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
310
306
  **kwargs: keyword arguments of the LinkNet architecture
311
307
 
312
308
  Returns:
313
- -------
314
309
  text detection architecture
315
310
  """
316
311
  return _linknet(
@@ -333,12 +328,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
333
328
  >>> out = model(input_tensor)
334
329
 
335
330
  Args:
336
- ----
337
331
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
338
332
  **kwargs: keyword arguments of the LinkNet architecture
339
333
 
340
334
  Returns:
341
- -------
342
335
  text detection architecture
343
336
  """
344
337
  return _linknet(
@@ -361,12 +354,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
361
354
  >>> out = model(input_tensor)
362
355
 
363
356
  Args:
364
- ----
365
357
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
366
358
  **kwargs: keyword arguments of the LinkNet architecture
367
359
 
368
360
  Returns:
369
- -------
370
361
  text detection architecture
371
362
  """
372
363
  return _linknet(
@@ -1,6 +1,6 @@
1
- from doctr.file_utils import is_tf_available
1
+ from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- else:
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Dict, List, Tuple, Union
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import torch
@@ -20,7 +20,6 @@ class DetectionPredictor(nn.Module):
20
20
  """Implements an object able to localize text elements in a document
21
21
 
22
22
  Args:
23
- ----
24
23
  pre_processor: transform inputs for easier batched model inference
25
24
  model: core detection architecture
26
25
  """
@@ -37,10 +36,10 @@ class DetectionPredictor(nn.Module):
37
36
  @torch.inference_mode()
38
37
  def forward(
39
38
  self,
40
- pages: List[Union[np.ndarray, torch.Tensor]],
39
+ pages: list[np.ndarray | torch.Tensor],
41
40
  return_maps: bool = False,
42
41
  **kwargs: Any,
43
- ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
42
+ ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
44
43
  # Extract parameters from the preprocessor
45
44
  preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
46
45
  symmetric_pad = self.pre_processor.resize.symmetric_pad
@@ -60,11 +59,11 @@ class DetectionPredictor(nn.Module):
60
59
  ]
61
60
  # Remove padding from loc predictions
62
61
  preds = _remove_padding(
63
- pages, # type: ignore[arg-type]
62
+ pages,
64
63
  [pred for batch in predicted_batches for pred in batch["preds"]],
65
64
  preserve_aspect_ratio=preserve_aspect_ratio,
66
65
  symmetric_pad=symmetric_pad,
67
- assume_straight_pages=assume_straight_pages,
66
+ assume_straight_pages=assume_straight_pages, # type: ignore[arg-type]
68
67
  )
69
68
 
70
69
  if return_maps:
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Dict, List, Tuple, Union
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import tensorflow as tf
@@ -20,12 +20,11 @@ class DetectionPredictor(NestedObject):
20
20
  """Implements an object able to localize text elements in a document
21
21
 
22
22
  Args:
23
- ----
24
23
  pre_processor: transform inputs for easier batched model inference
25
24
  model: core detection architecture
26
25
  """
27
26
 
28
- _children_names: List[str] = ["pre_processor", "model"]
27
+ _children_names: list[str] = ["pre_processor", "model"]
29
28
 
30
29
  def __init__(
31
30
  self,
@@ -37,10 +36,10 @@ class DetectionPredictor(NestedObject):
37
36
 
38
37
  def __call__(
39
38
  self,
40
- pages: List[Union[np.ndarray, tf.Tensor]],
39
+ pages: list[np.ndarray | tf.Tensor],
41
40
  return_maps: bool = False,
42
41
  **kwargs: Any,
43
- ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
42
+ ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
44
43
  # Extract parameters from the preprocessor
45
44
  preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
46
45
  symmetric_pad = self.pre_processor.resize.symmetric_pad
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, List
6
+ from typing import Any
7
7
 
8
8
  from doctr.file_utils import is_tf_available, is_torch_available
9
9
 
@@ -14,7 +14,7 @@ from .predictor import DetectionPredictor
14
14
 
15
15
  __all__ = ["detection_predictor"]
16
16
 
17
- ARCHS: List[str]
17
+ ARCHS: list[str]
18
18
 
19
19
 
20
20
  if is_tf_available():
@@ -56,7 +56,14 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
56
56
  if isinstance(_model, detection.FAST):
57
57
  _model = reparameterize(_model)
58
58
  else:
59
- if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
59
+ allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
60
+ if is_torch_available():
61
+ # Adding the type for torch compiled models to the allowed architectures
62
+ from doctr.models.utils import _CompiledModule
63
+
64
+ allowed_archs.append(_CompiledModule)
65
+
66
+ if not isinstance(arch, tuple(allowed_archs)):
60
67
  raise ValueError(f"unknown architecture: {type(arch)}")
61
68
 
62
69
  _model = arch
@@ -79,6 +86,9 @@ def detection_predictor(
79
86
  arch: Any = "fast_base",
80
87
  pretrained: bool = False,
81
88
  assume_straight_pages: bool = True,
89
+ preserve_aspect_ratio: bool = True,
90
+ symmetric_pad: bool = True,
91
+ batch_size: int = 2,
82
92
  **kwargs: Any,
83
93
  ) -> DetectionPredictor:
84
94
  """Text detection architecture.
@@ -90,14 +100,24 @@ def detection_predictor(
90
100
  >>> out = model([input_page])
91
101
 
92
102
  Args:
93
- ----
94
103
  arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
95
104
  pretrained: If True, returns a model pre-trained on our text detection dataset
96
105
  assume_straight_pages: If True, fit straight boxes to the page
106
+ preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
107
+ running the detection model on it
108
+ symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
109
+ batch_size: number of samples the model processes in parallel
97
110
  **kwargs: optional keyword arguments passed to the architecture
98
111
 
99
112
  Returns:
100
- -------
101
113
  Detection predictor
102
114
  """
103
- return _predictor(arch, pretrained, assume_straight_pages, **kwargs)
115
+ return _predictor(
116
+ arch=arch,
117
+ pretrained=pretrained,
118
+ assume_straight_pages=assume_straight_pages,
119
+ preserve_aspect_ratio=preserve_aspect_ratio,
120
+ symmetric_pad=symmetric_pad,
121
+ batch_size=batch_size,
122
+ **kwargs,
123
+ )
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -61,7 +61,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
61
61
  """Save model and config to disk for pushing to huggingface hub
62
62
 
63
63
  Args:
64
- ----
65
64
  model: TF or PyTorch model to be saved
66
65
  save_dir: directory to save model and config
67
66
  arch: architecture name
@@ -97,7 +96,6 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
97
96
  >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
98
97
 
99
98
  Args:
100
- ----
101
99
  model: TF or PyTorch model to be saved
102
100
  model_name: name of the model which is also the repository name
103
101
  task: task name
@@ -114,9 +112,9 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
114
112
  # default readme
115
113
  readme = textwrap.dedent(
116
114
  f"""
117
- ---
115
+
118
116
  language: en
119
- ---
117
+
120
118
 
121
119
  <p align="center">
122
120
  <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
@@ -190,12 +188,10 @@ def from_hub(repo_id: str, **kwargs: Any):
190
188
  >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")
191
189
 
192
190
  Args:
193
- ----
194
191
  repo_id: HuggingFace model hub repo
195
192
  kwargs: kwargs of `hf_hub_download` or `snapshot_download`
196
193
 
197
194
  Returns:
198
- -------
199
195
  Model loaded with the checkpoint
200
196
  """
201
197
  # Get the config
@@ -1,6 +1,6 @@
1
- from doctr.file_utils import is_tf_available
1
+ from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- else:
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]