python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/detection/linknet/pytorch.py
@@ -1,9 +1,10 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Callable, Dict, List, Optional, Tuple
+ from collections.abc import Callable
+ from typing import Any

  import numpy as np
  import torch
@@ -20,7 +21,7 @@ from .base import LinkNetPostProcessor, _LinkNet
  __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]


- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
      "linknet_resnet18": {
          "input_shape": (3, 1024, 1024),
          "mean": (0.798, 0.785, 0.772),
@@ -43,7 +44,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {


  class LinkNetFPN(nn.Module):
-     def __init__(self, layer_shapes: List[Tuple[int, int, int]]) -> None:
+     def __init__(self, layer_shapes: list[tuple[int, int, int]]) -> None:
          super().__init__()
          strides = [
              1 if (in_shape[-1] == out_shape[-1]) else 2
@@ -74,7 +75,7 @@ class LinkNetFPN(nn.Module):
              nn.ReLU(inplace=True),
          )

-     def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
+     def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
          out = feats[-1]
          for decoder, fmap in zip(self.decoders[::-1], feats[:-1][::-1]):
              out = decoder(out) + fmap
@@ -89,7 +90,6 @@ class LinkNet(nn.Module, _LinkNet):
      <https://arxiv.org/pdf/1707.03718.pdf>`_.

      Args:
-     ----
          feature extractor: the backbone serving as feature extractor
          bin_thresh: threshold for binarization of the output feature map
          box_thresh: minimal objectness score to consider a box
@@ -108,8 +108,8 @@ class LinkNet(nn.Module, _LinkNet):
          head_chans: int = 32,
          assume_straight_pages: bool = True,
          exportable: bool = False,
-         cfg: Optional[Dict[str, Any]] = None,
-         class_names: List[str] = [CLASS_NAME],
+         cfg: dict[str, Any] | None = None,
+         class_names: list[str] = [CLASS_NAME],
      ) -> None:
          super().__init__()
          self.class_names = class_names
@@ -160,19 +160,28 @@ class LinkNet(nn.Module, _LinkNet):
                  m.weight.data.fill_(1.0)
                  m.bias.data.zero_()

+     def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+         """Load pretrained parameters onto the model
+
+         Args:
+             path_or_url: the path or URL to the model parameters (checkpoint)
+             **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+         """
+         load_pretrained_params(self, path_or_url, **kwargs)
+
      def forward(
          self,
          x: torch.Tensor,
-         target: Optional[List[np.ndarray]] = None,
+         target: list[np.ndarray] | None = None,
          return_model_output: bool = False,
          return_preds: bool = False,
          **kwargs: Any,
-     ) -> Dict[str, Any]:
+     ) -> dict[str, Any]:
          feats = self.feat_extractor(x)
          logits = self.fpn([feats[str(idx)] for idx in range(len(feats))])
          logits = self.classifier(logits)

-         out: Dict[str, Any] = {}
+         out: dict[str, Any] = {}
          if self.exportable:
              out["logits"] = logits
              return out
@@ -183,11 +192,16 @@ class LinkNet(nn.Module, _LinkNet):
          out["out_map"] = prob_map

          if target is None or return_preds:
-             # Post-process boxes
-             out["preds"] = [
-                 dict(zip(self.class_names, preds))
-                 for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
-             ]
+             # Disable for torch.compile compatibility
+             @torch.compiler.disable  # type: ignore[attr-defined]
+             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
+                 return [
+                     dict(zip(self.class_names, preds))
+                     for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                 ]
+
+             # Post-process boxes (keep only text predictions)
+             out["preds"] = _postprocess(prob_map)

          if target is not None:
              loss = self.compute_loss(logits, target)
@@ -198,7 +212,7 @@
      def compute_loss(
          self,
          out_map: torch.Tensor,
-         target: List[np.ndarray],
+         target: list[np.ndarray],
          gamma: float = 2.0,
          alpha: float = 0.5,
          eps: float = 1e-8,
@@ -207,7 +221,6 @@
          <https://github.com/tensorflow/addons/>`_.

          Args:
-         ----
              out_map: output feature map of the model of shape (N, num_classes, H, W)
              target: list of dictionary where each dict has a `boxes` and a `flags` entry
              gamma: modulating factor in the focal loss formula
@@ -215,7 +228,6 @@
              eps: epsilon factor in dice loss

          Returns:
-         -------
              A loss tensor
          """
          _target, _mask = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
@@ -252,9 +264,9 @@ def _linknet(
      arch: str,
      pretrained: bool,
      backbone_fn: Callable[[bool], nn.Module],
-     fpn_layers: List[str],
+     fpn_layers: list[str],
      pretrained_backbone: bool = True,
-     ignore_keys: Optional[List[str]] = None,
+     ignore_keys: list[str] | None = None,
      **kwargs: Any,
  ) -> LinkNet:
      pretrained_backbone = pretrained_backbone and not pretrained
@@ -279,7 +291,7 @@ def _linknet(
          _ignore_keys = (
              ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
          )
-         load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+         model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

      return model

@@ -295,12 +307,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
      >>> out = model(input_tensor)

      Args:
-     ----
          pretrained (bool): If True, returns a model pre-trained on our text detection dataset
          **kwargs: keyword arguments of the LinkNet architecture

      Returns:
-     -------
          text detection architecture
      """
      return _linknet(
@@ -327,12 +337,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
      >>> out = model(input_tensor)

      Args:
-     ----
          pretrained (bool): If True, returns a model pre-trained on our text detection dataset
          **kwargs: keyword arguments of the LinkNet architecture

      Returns:
-     -------
          text detection architecture
      """
      return _linknet(
@@ -359,12 +367,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
      >>> out = model(input_tensor)

      Args:
-     ----
          pretrained (bool): If True, returns a model pre-trained on our text detection dataset
          **kwargs: keyword arguments of the LinkNet architecture

      Returns:
-     -------
          text detection architecture
      """
      return _linknet(
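The PyTorch LinkNet now exposes a public `from_pretrained` method, and `_linknet` loads checkpoints through it instead of calling `load_pretrained_params` from the outside; post-processing is also wrapped in a `torch.compiler.disable`-decorated closure so the model stays compatible with `torch.compile`. A minimal usage sketch of the new entry point; the commented-out checkpoint path is a placeholder, not a file shipped with the wheel:

import torch

from doctr.models import linknet_resnet18

# Pretrained weights are fetched internally via model.from_pretrained(url)
model = linknet_resnet18(pretrained=True).eval()

# Your own fine-tuned weights would go through the same entry point, e.g.:
# model.from_pretrained("path/to/linknet_checkpoint.pt")  # placeholder path

with torch.inference_mode():
    out = model(torch.rand((1, 3, 1024, 1024)), return_preds=True)

print(out["preds"])  # one dict per page, mapping class names to detected boxes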
doctr/models/detection/linknet/tensorflow.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization

  from copy import deepcopy
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -27,7 +27,7 @@ from .base import LinkNetPostProcessor, _LinkNet

  __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]

- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
      "linknet_resnet18": {
          "mean": (0.798, 0.785, 0.772),
          "std": (0.264, 0.2749, 0.287),
@@ -73,7 +73,7 @@ class LinkNetFPN(Model, NestedObject):
      def __init__(
          self,
          out_chans: int,
-         in_shapes: List[Tuple[int, ...]],
+         in_shapes: list[tuple[int, ...]],
      ) -> None:
          super().__init__()
          self.out_chans = out_chans
@@ -85,7 +85,7 @@ class LinkNetFPN(Model, NestedObject):
              for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
          ]

-     def call(self, x: List[tf.Tensor], **kwargs: Any) -> tf.Tensor:
+     def call(self, x: list[tf.Tensor], **kwargs: Any) -> tf.Tensor:
          out = 0
          for decoder, fmap in zip(self.decoders, x[::-1]):
              out = decoder(out + fmap, **kwargs)
@@ -100,7 +100,6 @@ class LinkNet(_LinkNet, Model):
      <https://arxiv.org/pdf/1707.03718.pdf>`_.

      Args:
-     ----
          feature extractor: the backbone serving as feature extractor
          fpn_channels: number of channels each extracted feature maps is mapped to
          bin_thresh: threshold for binarization of the output feature map
@@ -111,7 +110,7 @@ class LinkNet(_LinkNet, Model):
          class_names: list of class names
      """

-     _children_names: List[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]
+     _children_names: list[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]

      def __init__(
          self,
@@ -121,8 +120,8 @@ class LinkNet(_LinkNet, Model):
          box_thresh: float = 0.1,
          assume_straight_pages: bool = True,
          exportable: bool = False,
-         cfg: Optional[Dict[str, Any]] = None,
-         class_names: List[str] = [CLASS_NAME],
+         cfg: dict[str, Any] | None = None,
+         class_names: list[str] = [CLASS_NAME],
      ) -> None:
          super().__init__(cfg=cfg)

@@ -164,10 +163,19 @@ class LinkNet(_LinkNet, Model):
              assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
          )

+     def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+         """Load pretrained parameters onto the model
+
+         Args:
+             path_or_url: the path or URL to the model parameters (checkpoint)
+             **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+         """
+         load_pretrained_params(self, path_or_url, **kwargs)
+
      def compute_loss(
          self,
          out_map: tf.Tensor,
-         target: List[Dict[str, np.ndarray]],
+         target: list[dict[str, np.ndarray]],
          gamma: float = 2.0,
          alpha: float = 0.5,
          eps: float = 1e-8,
@@ -176,7 +184,6 @@ class LinkNet(_LinkNet, Model):
          <https://github.com/tensorflow/addons/>`_.

          Args:
-         ----
              out_map: output feature map of the model of shape N x H x W x 1
              target: list of dictionary where each dict has a `boxes` and a `flags` entry
              gamma: modulating factor in the focal loss formula
@@ -184,7 +191,6 @@ class LinkNet(_LinkNet, Model):
              eps: epsilon factor in dice loss

          Returns:
-         -------
              A loss tensor
          """
          seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True)
@@ -218,16 +224,16 @@ class LinkNet(_LinkNet, Model):
      def call(
          self,
          x: tf.Tensor,
-         target: Optional[List[Dict[str, np.ndarray]]] = None,
+         target: list[dict[str, np.ndarray]] | None = None,
          return_model_output: bool = False,
          return_preds: bool = False,
          **kwargs: Any,
-     ) -> Dict[str, Any]:
+     ) -> dict[str, Any]:
          feat_maps = self.feat_extractor(x, **kwargs)
          logits = self.fpn(feat_maps, **kwargs)
          logits = self.classifier(logits, **kwargs)

-         out: Dict[str, tf.Tensor] = {}
+         out: dict[str, tf.Tensor] = {}
          if self.exportable:
              out["logits"] = logits
              return out
@@ -253,9 +259,9 @@ def _linknet(
      arch: str,
      pretrained: bool,
      backbone_fn,
-     fpn_layers: List[str],
+     fpn_layers: list[str],
      pretrained_backbone: bool = True,
-     input_shape: Optional[Tuple[int, int, int]] = None,
+     input_shape: tuple[int, int, int] | None = None,
      **kwargs: Any,
  ) -> LinkNet:
      pretrained_backbone = pretrained_backbone and not pretrained
@@ -285,8 +291,7 @@ def _linknet(
      # Load pretrained parameters
      if pretrained:
          # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
-         load_pretrained_params(
-             model,
+         model.from_pretrained(
              _cfg["url"],
              skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
          )
@@ -305,12 +310,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
      >>> out = model(input_tensor)

      Args:
-     ----
          pretrained (bool): If True, returns a model pre-trained on our text detection dataset
          **kwargs: keyword arguments of the LinkNet architecture

      Returns:
-     -------
          text detection architecture
      """
      return _linknet(
@@ -333,12 +336,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
      >>> out = model(input_tensor)

      Args:
-     ----
          pretrained (bool): If True, returns a model pre-trained on our text detection dataset
          **kwargs: keyword arguments of the LinkNet architecture

      Returns:
-     -------
          text detection architecture
      """
      return _linknet(
@@ -361,12 +362,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
      >>> out = model(input_tensor)

      Args:
-     ----
          pretrained (bool): If True, returns a model pre-trained on our text detection dataset
          **kwargs: keyword arguments of the LinkNet architecture

      Returns:
-     -------
          text detection architecture
      """
      return _linknet(
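The TensorFlow LinkNet receives the same treatment, and the pattern repeats across most files in the list above: the `typing` aliases (`Dict`, `List`, `Tuple`, `Optional`, `Union`) give way to built-in generics and `X | None` unions, with `Callable` now imported from `collections.abc`. A toy sketch of the annotation style only (not doctr code):

from collections.abc import Callable
from typing import Any

# Before: def build(shapes: List[Tuple[int, int]], cfg: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
# After, in the style used throughout the 0.12.0 hunks:
def build(
    shapes: list[tuple[int, int]],
    factory: Callable[[int], Any],
    cfg: dict[str, Any] | None = None,
) -> dict[str, Any]:
    # Toy body: only the annotations mirror the change above.
    return {"layers": [factory(height) for height, _ in shapes], "cfg": cfg or {}}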
doctr/models/detection/predictor/__init__.py
@@ -1,6 +1,6 @@
- from doctr.file_utils import is_tf_available
+ from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
-     from .tensorflow import *
- else:
-     from .pytorch import *  # type: ignore[assignment]
+ if is_torch_available():
+     from .pytorch import *
+ elif is_tf_available():
+     from .tensorflow import *  # type: ignore[assignment]
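This `__init__.py` (and the `kie_predictor` one further below) flips the import priority: when both frameworks are installed, the PyTorch implementation is now loaded first and TensorFlow becomes the fallback. A small sketch of how calling code can check which backend was resolved, using only the helpers referenced in the hunk:

from doctr.file_utils import is_tf_available, is_torch_available

if is_torch_available():
    backend = "pytorch"  # preferred from this release onward when both frameworks are present
elif is_tf_available():
    backend = "tensorflow"
else:
    raise RuntimeError("doctr needs either PyTorch or TensorFlow to be installed")

print(f"doctr predictors will use the {backend} implementation")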
doctr/models/detection/predictor/pytorch.py
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import numpy as np
  import torch
@@ -20,7 +20,6 @@ class DetectionPredictor(nn.Module):
      """Implements an object able to localize text elements in a document

      Args:
-     ----
          pre_processor: transform inputs for easier batched model inference
          model: core detection architecture
      """
@@ -37,10 +36,10 @@
      @torch.inference_mode()
      def forward(
          self,
-         pages: List[Union[np.ndarray, torch.Tensor]],
+         pages: list[np.ndarray | torch.Tensor],
          return_maps: bool = False,
          **kwargs: Any,
-     ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
+     ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
          # Extract parameters from the preprocessor
          preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
          symmetric_pad = self.pre_processor.resize.symmetric_pad
@@ -60,11 +59,11 @@
          ]
          # Remove padding from loc predictions
          preds = _remove_padding(
-             pages,  # type: ignore[arg-type]
+             pages,
              [pred for batch in predicted_batches for pred in batch["preds"]],
              preserve_aspect_ratio=preserve_aspect_ratio,
              symmetric_pad=symmetric_pad,
-             assume_straight_pages=assume_straight_pages,
+             assume_straight_pages=assume_straight_pages,  # type: ignore[arg-type]
          )

          if return_maps:
doctr/models/detection/predictor/tensorflow.py
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any

  import numpy as np
  import tensorflow as tf
@@ -20,12 +20,11 @@ class DetectionPredictor(NestedObject):
      """Implements an object able to localize text elements in a document

      Args:
-     ----
          pre_processor: transform inputs for easier batched model inference
          model: core detection architecture
      """

-     _children_names: List[str] = ["pre_processor", "model"]
+     _children_names: list[str] = ["pre_processor", "model"]

      def __init__(
          self,
@@ -37,10 +36,10 @@

      def __call__(
          self,
-         pages: List[Union[np.ndarray, tf.Tensor]],
+         pages: list[np.ndarray | tf.Tensor],
          return_maps: bool = False,
          **kwargs: Any,
-     ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
+     ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
          # Extract parameters from the preprocessor
          preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
          symmetric_pad = self.pre_processor.resize.symmetric_pad
doctr/models/detection/zoo.py
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, List
+ from typing import Any

  from doctr.file_utils import is_tf_available, is_torch_available

@@ -14,7 +14,7 @@ from .predictor import DetectionPredictor

  __all__ = ["detection_predictor"]

- ARCHS: List[str]
+ ARCHS: list[str]


  if is_tf_available():
@@ -56,7 +56,14 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
          if isinstance(_model, detection.FAST):
              _model = reparameterize(_model)
      else:
-         if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
+         allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
+         if is_torch_available():
+             # Adding the type for torch compiled models to the allowed architectures
+             from doctr.models.utils import _CompiledModule
+
+             allowed_archs.append(_CompiledModule)
+
+         if not isinstance(arch, tuple(allowed_archs)):
              raise ValueError(f"unknown architecture: {type(arch)}")

          _model = arch
@@ -79,6 +86,9 @@ def detection_predictor(
      arch: Any = "fast_base",
      pretrained: bool = False,
      assume_straight_pages: bool = True,
+     preserve_aspect_ratio: bool = True,
+     symmetric_pad: bool = True,
+     batch_size: int = 2,
      **kwargs: Any,
  ) -> DetectionPredictor:
      """Text detection architecture.
@@ -90,14 +100,24 @@
      >>> out = model([input_page])

      Args:
-     ----
          arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
          pretrained: If True, returns a model pre-trained on our text detection dataset
          assume_straight_pages: If True, fit straight boxes to the page
+         preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
+             running the detection model on it
+         symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
+         batch_size: number of samples the model processes in parallel
          **kwargs: optional keyword arguments passed to the architecture

      Returns:
-     -------
          Detection predictor
      """
-     return _predictor(arch, pretrained, assume_straight_pages, **kwargs)
+     return _predictor(
+         arch=arch,
+         pretrained=pretrained,
+         assume_straight_pages=assume_straight_pages,
+         preserve_aspect_ratio=preserve_aspect_ratio,
+         symmetric_pad=symmetric_pad,
+         batch_size=batch_size,
+         **kwargs,
+     )
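`detection_predictor` now documents `preserve_aspect_ratio`, `symmetric_pad` and `batch_size` as first-class arguments and forwards everything to `_predictor` by keyword. A hedged usage sketch; the random array below stands in for a real page image:

import numpy as np

from doctr.models import detection_predictor

predictor = detection_predictor(
    arch="fast_base",
    pretrained=True,
    assume_straight_pages=True,
    preserve_aspect_ratio=True,  # pad the page rather than distorting it
    symmetric_pad=True,          # pad on both sides instead of bottom-right only
    batch_size=2,                # pages processed per forward pass
)

page = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)  # stand-in document image
out = predictor([page])  # one dict per page, mapping each class name to its detected boxes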
doctr/models/factory/hub.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -61,7 +61,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
      """Save model and config to disk for pushing to huggingface hub

      Args:
-     ----
          model: TF or PyTorch model to be saved
          save_dir: directory to save model and config
          arch: architecture name
@@ -97,7 +96,6 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
      >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

      Args:
-     ----
          model: TF or PyTorch model to be saved
          model_name: name of the model which is also the repository name
          task: task name
@@ -114,9 +112,9 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
      # default readme
      readme = textwrap.dedent(
          f"""
-         ---
+
          language: en
-         ---
+

          <p align="center">
          <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
@@ -190,12 +188,10 @@ def from_hub(repo_id: str, **kwargs: Any):
      >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")

      Args:
-     ----
          repo_id: HuggingFace model hub repo
          kwargs: kwargs of `hf_hub_download` or `snapshot_download`

      Returns:
-     -------
          Model loaded with the checkpoint
      """
      # Get the config
@@ -221,10 +217,10 @@ def from_hub(repo_id: str, **kwargs: Any):

      # Load checkpoint
      if is_torch_available():
-         state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
-         model.load_state_dict(state_dict)
+         weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
      else:  # tf
          weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
-         model.load_weights(weights)
+
+     model.from_pretrained(weights)

      return model
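`from_hub` no longer calls `load_state_dict` / `load_weights` itself: both backends download the checkpoint file and hand it to the model's `from_pretrained` method. A sketch of the round trip under stated assumptions: the repository id below is a placeholder and `push_to_hf_hub` needs an authenticated Hugging Face session:

from doctr.models import crnn_mobilenet_v3_small
from doctr.models.factory import from_hub, push_to_hf_hub

# Publishing is unchanged on the surface (call shape taken from the docstring above);
# this requires being logged in to the Hugging Face Hub.
model = crnn_mobilenet_v3_small(pretrained=True)
push_to_hf_hub(model, model_name="my-crnn-demo", task="recognition", arch="crnn_mobilenet_v3_small")

# Loading now routes the downloaded weights through model.from_pretrained(weights) on either backend.
restored = from_hub("username/my-crnn-demo")  # placeholder repo id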
doctr/models/kie_predictor/__init__.py
@@ -1,6 +1,6 @@
- from doctr.file_utils import is_tf_available
+ from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
-     from .tensorflow import *
- else:
-     from .pytorch import *  # type: ignore[assignment]
+ if is_torch_available():
+     from .pytorch import *
+ elif is_tf_available():
+     from .tensorflow import *  # type: ignore[assignment]
doctr/models/kie_predictor/base.py
@@ -1,9 +1,9 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Optional
+ from typing import Any

  from doctr.models.builder import KIEDocumentBuilder

@@ -17,7 +17,6 @@ class _KIEPredictor(_OCRPredictor):
      """Implements an object able to localize and identify text elements in a set of documents

      Args:
-     ----
          assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
              without rotated textual elements.
          straighten_pages: if True, estimates the page general orientation based on the median line orientation.
@@ -30,8 +29,8 @@ class _KIEPredictor(_OCRPredictor):
          kwargs: keyword args of `DocumentBuilder`
      """

-     crop_orientation_predictor: Optional[OrientationPredictor]
-     page_orientation_predictor: Optional[OrientationPredictor]
+     crop_orientation_predictor: OrientationPredictor | None
+     page_orientation_predictor: OrientationPredictor | None

      def __init__(
          self,