python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,9 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Callable, Dict, List, Optional, Tuple
6
+ from collections.abc import Callable
7
+ from typing import Any
7
8
 
8
9
  import numpy as np
9
10
  import torch
@@ -20,7 +21,7 @@ from .base import LinkNetPostProcessor, _LinkNet
20
21
  __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
21
22
 
22
23
 
23
- default_cfgs: Dict[str, Dict[str, Any]] = {
24
+ default_cfgs: dict[str, dict[str, Any]] = {
24
25
  "linknet_resnet18": {
25
26
  "input_shape": (3, 1024, 1024),
26
27
  "mean": (0.798, 0.785, 0.772),
@@ -43,7 +44,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
43
44
 
44
45
 
45
46
  class LinkNetFPN(nn.Module):
46
- def __init__(self, layer_shapes: List[Tuple[int, int, int]]) -> None:
47
+ def __init__(self, layer_shapes: list[tuple[int, int, int]]) -> None:
47
48
  super().__init__()
48
49
  strides = [
49
50
  1 if (in_shape[-1] == out_shape[-1]) else 2
@@ -74,7 +75,7 @@ class LinkNetFPN(nn.Module):
74
75
  nn.ReLU(inplace=True),
75
76
  )
76
77
 
77
- def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
78
+ def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
78
79
  out = feats[-1]
79
80
  for decoder, fmap in zip(self.decoders[::-1], feats[:-1][::-1]):
80
81
  out = decoder(out) + fmap
@@ -89,7 +90,6 @@ class LinkNet(nn.Module, _LinkNet):
89
90
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
90
91
 
91
92
  Args:
92
- ----
93
93
  feature extractor: the backbone serving as feature extractor
94
94
  bin_thresh: threshold for binarization of the output feature map
95
95
  box_thresh: minimal objectness score to consider a box
@@ -108,8 +108,8 @@ class LinkNet(nn.Module, _LinkNet):
108
108
  head_chans: int = 32,
109
109
  assume_straight_pages: bool = True,
110
110
  exportable: bool = False,
111
- cfg: Optional[Dict[str, Any]] = None,
112
- class_names: List[str] = [CLASS_NAME],
111
+ cfg: dict[str, Any] | None = None,
112
+ class_names: list[str] = [CLASS_NAME],
113
113
  ) -> None:
114
114
  super().__init__()
115
115
  self.class_names = class_names
@@ -163,16 +163,16 @@ class LinkNet(nn.Module, _LinkNet):
163
163
  def forward(
164
164
  self,
165
165
  x: torch.Tensor,
166
- target: Optional[List[np.ndarray]] = None,
166
+ target: list[np.ndarray] | None = None,
167
167
  return_model_output: bool = False,
168
168
  return_preds: bool = False,
169
169
  **kwargs: Any,
170
- ) -> Dict[str, Any]:
170
+ ) -> dict[str, Any]:
171
171
  feats = self.feat_extractor(x)
172
172
  logits = self.fpn([feats[str(idx)] for idx in range(len(feats))])
173
173
  logits = self.classifier(logits)
174
174
 
175
- out: Dict[str, Any] = {}
175
+ out: dict[str, Any] = {}
176
176
  if self.exportable:
177
177
  out["logits"] = logits
178
178
  return out
@@ -183,11 +183,16 @@ class LinkNet(nn.Module, _LinkNet):
183
183
  out["out_map"] = prob_map
184
184
 
185
185
  if target is None or return_preds:
186
- # Post-process boxes
187
- out["preds"] = [
188
- dict(zip(self.class_names, preds))
189
- for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
190
- ]
186
+ # Disable for torch.compile compatibility
187
+ @torch.compiler.disable # type: ignore[attr-defined]
188
+ def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
189
+ return [
190
+ dict(zip(self.class_names, preds))
191
+ for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
192
+ ]
193
+
194
+ # Post-process boxes (keep only text predictions)
195
+ out["preds"] = _postprocess(prob_map)
191
196
 
192
197
  if target is not None:
193
198
  loss = self.compute_loss(logits, target)
@@ -198,7 +203,7 @@ class LinkNet(nn.Module, _LinkNet):
198
203
  def compute_loss(
199
204
  self,
200
205
  out_map: torch.Tensor,
201
- target: List[np.ndarray],
206
+ target: list[np.ndarray],
202
207
  gamma: float = 2.0,
203
208
  alpha: float = 0.5,
204
209
  eps: float = 1e-8,
@@ -207,7 +212,6 @@ class LinkNet(nn.Module, _LinkNet):
207
212
  <https://github.com/tensorflow/addons/>`_.
208
213
 
209
214
  Args:
210
- ----
211
215
  out_map: output feature map of the model of shape (N, num_classes, H, W)
212
216
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
213
217
  gamma: modulating factor in the focal loss formula
@@ -215,7 +219,6 @@ class LinkNet(nn.Module, _LinkNet):
215
219
  eps: epsilon factor in dice loss
216
220
 
217
221
  Returns:
218
- -------
219
222
  A loss tensor
220
223
  """
221
224
  _target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
@@ -252,9 +255,9 @@ def _linknet(
252
255
  arch: str,
253
256
  pretrained: bool,
254
257
  backbone_fn: Callable[[bool], nn.Module],
255
- fpn_layers: List[str],
258
+ fpn_layers: list[str],
256
259
  pretrained_backbone: bool = True,
257
- ignore_keys: Optional[List[str]] = None,
260
+ ignore_keys: list[str] | None = None,
258
261
  **kwargs: Any,
259
262
  ) -> LinkNet:
260
263
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -295,12 +298,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
295
298
  >>> out = model(input_tensor)
296
299
 
297
300
  Args:
298
- ----
299
301
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
300
302
  **kwargs: keyword arguments of the LinkNet architecture
301
303
 
302
304
  Returns:
303
- -------
304
305
  text detection architecture
305
306
  """
306
307
  return _linknet(
@@ -327,12 +328,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
327
328
  >>> out = model(input_tensor)
328
329
 
329
330
  Args:
330
- ----
331
331
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
332
332
  **kwargs: keyword arguments of the LinkNet architecture
333
333
 
334
334
  Returns:
335
- -------
336
335
  text detection architecture
337
336
  """
338
337
  return _linknet(
@@ -359,12 +358,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
359
358
  >>> out = model(input_tensor)
360
359
 
361
360
  Args:
362
- ----
363
361
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
364
362
  **kwargs: keyword arguments of the LinkNet architecture
365
363
 
366
364
  Returns:
367
- -------
368
365
  text detection architecture
369
366
  """
370
367
  return _linknet(
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,40 +6,45 @@
6
6
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
7
 
8
8
  from copy import deepcopy
9
- from typing import Any, Dict, List, Optional, Tuple
9
+ from typing import Any
10
10
 
11
11
  import numpy as np
12
12
  import tensorflow as tf
13
- from tensorflow import keras
14
- from tensorflow.keras import Model, Sequential, layers
13
+ from tensorflow.keras import Model, Sequential, layers, losses
15
14
 
16
15
  from doctr.file_utils import CLASS_NAME
17
16
  from doctr.models.classification import resnet18, resnet34, resnet50
18
- from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
17
+ from doctr.models.utils import (
18
+ IntermediateLayerGetter,
19
+ _bf16_to_float32,
20
+ _build_model,
21
+ conv_sequence,
22
+ load_pretrained_params,
23
+ )
19
24
  from doctr.utils.repr import NestedObject
20
25
 
21
26
  from .base import LinkNetPostProcessor, _LinkNet
22
27
 
23
28
  __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
24
29
 
25
- default_cfgs: Dict[str, Dict[str, Any]] = {
30
+ default_cfgs: dict[str, dict[str, Any]] = {
26
31
  "linknet_resnet18": {
27
32
  "mean": (0.798, 0.785, 0.772),
28
33
  "std": (0.264, 0.2749, 0.287),
29
34
  "input_shape": (1024, 1024, 3),
30
- "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-b9ee56e6.zip&src=0",
35
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0",
31
36
  },
32
37
  "linknet_resnet34": {
33
38
  "mean": (0.798, 0.785, 0.772),
34
39
  "std": (0.264, 0.2749, 0.287),
35
40
  "input_shape": (1024, 1024, 3),
36
- "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-51909c56.zip&src=0",
41
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0",
37
42
  },
38
43
  "linknet_resnet50": {
39
44
  "mean": (0.798, 0.785, 0.772),
40
45
  "std": (0.264, 0.2749, 0.287),
41
46
  "input_shape": (1024, 1024, 3),
42
- "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-ac9f3829.zip&src=0",
47
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0",
43
48
  },
44
49
  }
45
50
 
@@ -68,7 +73,7 @@ class LinkNetFPN(Model, NestedObject):
68
73
  def __init__(
69
74
  self,
70
75
  out_chans: int,
71
- in_shapes: List[Tuple[int, ...]],
76
+ in_shapes: list[tuple[int, ...]],
72
77
  ) -> None:
73
78
  super().__init__()
74
79
  self.out_chans = out_chans
@@ -80,22 +85,21 @@ class LinkNetFPN(Model, NestedObject):
80
85
  for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
81
86
  ]
82
87
 
83
- def call(self, x: List[tf.Tensor]) -> tf.Tensor:
88
+ def call(self, x: list[tf.Tensor], **kwargs: Any) -> tf.Tensor:
84
89
  out = 0
85
90
  for decoder, fmap in zip(self.decoders, x[::-1]):
86
- out = decoder(out + fmap)
91
+ out = decoder(out + fmap, **kwargs)
87
92
  return out
88
93
 
89
94
  def extra_repr(self) -> str:
90
95
  return f"out_chans={self.out_chans}"
91
96
 
92
97
 
93
- class LinkNet(_LinkNet, keras.Model):
98
+ class LinkNet(_LinkNet, Model):
94
99
  """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
95
100
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
96
101
 
97
102
  Args:
98
- ----
99
103
  feature extractor: the backbone serving as feature extractor
100
104
  fpn_channels: number of channels each extracted feature maps is mapped to
101
105
  bin_thresh: threshold for binarization of the output feature map
@@ -106,7 +110,7 @@ class LinkNet(_LinkNet, keras.Model):
106
110
  class_names: list of class names
107
111
  """
108
112
 
109
- _children_names: List[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]
113
+ _children_names: list[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]
110
114
 
111
115
  def __init__(
112
116
  self,
@@ -116,8 +120,8 @@ class LinkNet(_LinkNet, keras.Model):
116
120
  box_thresh: float = 0.1,
117
121
  assume_straight_pages: bool = True,
118
122
  exportable: bool = False,
119
- cfg: Optional[Dict[str, Any]] = None,
120
- class_names: List[str] = [CLASS_NAME],
123
+ cfg: dict[str, Any] | None = None,
124
+ class_names: list[str] = [CLASS_NAME],
121
125
  ) -> None:
122
126
  super().__init__(cfg=cfg)
123
127
 
@@ -162,7 +166,7 @@ class LinkNet(_LinkNet, keras.Model):
162
166
  def compute_loss(
163
167
  self,
164
168
  out_map: tf.Tensor,
165
- target: List[Dict[str, np.ndarray]],
169
+ target: list[dict[str, np.ndarray]],
166
170
  gamma: float = 2.0,
167
171
  alpha: float = 0.5,
168
172
  eps: float = 1e-8,
@@ -171,7 +175,6 @@ class LinkNet(_LinkNet, keras.Model):
171
175
  <https://github.com/tensorflow/addons/>`_.
172
176
 
173
177
  Args:
174
- ----
175
178
  out_map: output feature map of the model of shape N x H x W x 1
176
179
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
177
180
  gamma: modulating factor in the focal loss formula
@@ -179,7 +182,6 @@ class LinkNet(_LinkNet, keras.Model):
179
182
  eps: epsilon factor in dice loss
180
183
 
181
184
  Returns:
182
- -------
183
185
  A loss tensor
184
186
  """
185
187
  seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True)
@@ -187,7 +189,7 @@ class LinkNet(_LinkNet, keras.Model):
187
189
  seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
188
190
  seg_mask = tf.cast(seg_mask, tf.float32)
189
191
 
190
- bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
192
+ bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
191
193
  proba_map = tf.sigmoid(out_map)
192
194
 
193
195
  # Focal loss
@@ -213,16 +215,16 @@ class LinkNet(_LinkNet, keras.Model):
213
215
  def call(
214
216
  self,
215
217
  x: tf.Tensor,
216
- target: Optional[List[Dict[str, np.ndarray]]] = None,
218
+ target: list[dict[str, np.ndarray]] | None = None,
217
219
  return_model_output: bool = False,
218
220
  return_preds: bool = False,
219
221
  **kwargs: Any,
220
- ) -> Dict[str, Any]:
222
+ ) -> dict[str, Any]:
221
223
  feat_maps = self.feat_extractor(x, **kwargs)
222
224
  logits = self.fpn(feat_maps, **kwargs)
223
225
  logits = self.classifier(logits, **kwargs)
224
226
 
225
- out: Dict[str, tf.Tensor] = {}
227
+ out: dict[str, tf.Tensor] = {}
226
228
  if self.exportable:
227
229
  out["logits"] = logits
228
230
  return out
@@ -248,9 +250,9 @@ def _linknet(
248
250
  arch: str,
249
251
  pretrained: bool,
250
252
  backbone_fn,
251
- fpn_layers: List[str],
253
+ fpn_layers: list[str],
252
254
  pretrained_backbone: bool = True,
253
- input_shape: Optional[Tuple[int, int, int]] = None,
255
+ input_shape: tuple[int, int, int] | None = None,
254
256
  **kwargs: Any,
255
257
  ) -> LinkNet:
256
258
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -275,9 +277,16 @@ def _linknet(
275
277
 
276
278
  # Build the model
277
279
  model = LinkNet(feat_extractor, cfg=_cfg, **kwargs)
280
+ _build_model(model)
281
+
278
282
  # Load pretrained parameters
279
283
  if pretrained:
280
- load_pretrained_params(model, _cfg["url"])
284
+ # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
285
+ load_pretrained_params(
286
+ model,
287
+ _cfg["url"],
288
+ skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
289
+ )
281
290
 
282
291
  return model
283
292
 
@@ -293,12 +302,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
293
302
  >>> out = model(input_tensor)
294
303
 
295
304
  Args:
296
- ----
297
305
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
298
306
  **kwargs: keyword arguments of the LinkNet architecture
299
307
 
300
308
  Returns:
301
- -------
302
309
  text detection architecture
303
310
  """
304
311
  return _linknet(
@@ -321,12 +328,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
321
328
  >>> out = model(input_tensor)
322
329
 
323
330
  Args:
324
- ----
325
331
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
326
332
  **kwargs: keyword arguments of the LinkNet architecture
327
333
 
328
334
  Returns:
329
- -------
330
335
  text detection architecture
331
336
  """
332
337
  return _linknet(
@@ -349,12 +354,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
349
354
  >>> out = model(input_tensor)
350
355
 
351
356
  Args:
352
- ----
353
357
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
354
358
  **kwargs: keyword arguments of the LinkNet architecture
355
359
 
356
360
  Returns:
357
- -------
358
361
  text detection architecture
359
362
  """
360
363
  return _linknet(
@@ -1,6 +1,6 @@
1
- from doctr.file_utils import is_tf_available
1
+ from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- else:
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Dict, List, Tuple, Union
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import torch
@@ -20,7 +20,6 @@ class DetectionPredictor(nn.Module):
20
20
  """Implements an object able to localize text elements in a document
21
21
 
22
22
  Args:
23
- ----
24
23
  pre_processor: transform inputs for easier batched model inference
25
24
  model: core detection architecture
26
25
  """
@@ -37,10 +36,10 @@ class DetectionPredictor(nn.Module):
37
36
  @torch.inference_mode()
38
37
  def forward(
39
38
  self,
40
- pages: List[Union[np.ndarray, torch.Tensor]],
39
+ pages: list[np.ndarray | torch.Tensor],
41
40
  return_maps: bool = False,
42
41
  **kwargs: Any,
43
- ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
42
+ ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
44
43
  # Extract parameters from the preprocessor
45
44
  preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
46
45
  symmetric_pad = self.pre_processor.resize.symmetric_pad
@@ -60,11 +59,11 @@ class DetectionPredictor(nn.Module):
60
59
  ]
61
60
  # Remove padding from loc predictions
62
61
  preds = _remove_padding(
63
- pages, # type: ignore[arg-type]
62
+ pages,
64
63
  [pred for batch in predicted_batches for pred in batch["preds"]],
65
64
  preserve_aspect_ratio=preserve_aspect_ratio,
66
65
  symmetric_pad=symmetric_pad,
67
- assume_straight_pages=assume_straight_pages,
66
+ assume_straight_pages=assume_straight_pages, # type: ignore[arg-type]
68
67
  )
69
68
 
70
69
  if return_maps:
@@ -1,13 +1,13 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Dict, List, Tuple, Union
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import tensorflow as tf
10
- from tensorflow import keras
10
+ from tensorflow.keras import Model
11
11
 
12
12
  from doctr.models.detection._utils import _remove_padding
13
13
  from doctr.models.preprocessor import PreProcessor
@@ -20,27 +20,26 @@ class DetectionPredictor(NestedObject):
20
20
  """Implements an object able to localize text elements in a document
21
21
 
22
22
  Args:
23
- ----
24
23
  pre_processor: transform inputs for easier batched model inference
25
24
  model: core detection architecture
26
25
  """
27
26
 
28
- _children_names: List[str] = ["pre_processor", "model"]
27
+ _children_names: list[str] = ["pre_processor", "model"]
29
28
 
30
29
  def __init__(
31
30
  self,
32
31
  pre_processor: PreProcessor,
33
- model: keras.Model,
32
+ model: Model,
34
33
  ) -> None:
35
34
  self.pre_processor = pre_processor
36
35
  self.model = model
37
36
 
38
37
  def __call__(
39
38
  self,
40
- pages: List[Union[np.ndarray, tf.Tensor]],
39
+ pages: list[np.ndarray | tf.Tensor],
41
40
  return_maps: bool = False,
42
41
  **kwargs: Any,
43
- ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
42
+ ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
44
43
  # Extract parameters from the preprocessor
45
44
  preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
46
45
  symmetric_pad = self.pre_processor.resize.symmetric_pad
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, List
6
+ from typing import Any
7
7
 
8
8
  from doctr.file_utils import is_tf_available, is_torch_available
9
9
 
@@ -14,7 +14,7 @@ from .predictor import DetectionPredictor
14
14
 
15
15
  __all__ = ["detection_predictor"]
16
16
 
17
- ARCHS: List[str]
17
+ ARCHS: list[str]
18
18
 
19
19
 
20
20
  if is_tf_available():
@@ -56,7 +56,14 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
56
56
  if isinstance(_model, detection.FAST):
57
57
  _model = reparameterize(_model)
58
58
  else:
59
- if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
59
+ allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
60
+ if is_torch_available():
61
+ # Adding the type for torch compiled models to the allowed architectures
62
+ from doctr.models.utils import _CompiledModule
63
+
64
+ allowed_archs.append(_CompiledModule)
65
+
66
+ if not isinstance(arch, tuple(allowed_archs)):
60
67
  raise ValueError(f"unknown architecture: {type(arch)}")
61
68
 
62
69
  _model = arch
@@ -79,6 +86,9 @@ def detection_predictor(
79
86
  arch: Any = "fast_base",
80
87
  pretrained: bool = False,
81
88
  assume_straight_pages: bool = True,
89
+ preserve_aspect_ratio: bool = True,
90
+ symmetric_pad: bool = True,
91
+ batch_size: int = 2,
82
92
  **kwargs: Any,
83
93
  ) -> DetectionPredictor:
84
94
  """Text detection architecture.
@@ -90,14 +100,24 @@ def detection_predictor(
90
100
  >>> out = model([input_page])
91
101
 
92
102
  Args:
93
- ----
94
103
  arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
95
104
  pretrained: If True, returns a model pre-trained on our text detection dataset
96
105
  assume_straight_pages: If True, fit straight boxes to the page
106
+ preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
107
+ running the detection model on it
108
+ symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
109
+ batch_size: number of samples the model processes in parallel
97
110
  **kwargs: optional keyword arguments passed to the architecture
98
111
 
99
112
  Returns:
100
- -------
101
113
  Detection predictor
102
114
  """
103
- return _predictor(arch, pretrained, assume_straight_pages, **kwargs)
115
+ return _predictor(
116
+ arch=arch,
117
+ pretrained=pretrained,
118
+ assume_straight_pages=assume_straight_pages,
119
+ preserve_aspect_ratio=preserve_aspect_ratio,
120
+ symmetric_pad=symmetric_pad,
121
+ batch_size=batch_size,
122
+ **kwargs,
123
+ )
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -20,7 +20,6 @@ from huggingface_hub import (
20
20
  get_token_permission,
21
21
  hf_hub_download,
22
22
  login,
23
- snapshot_download,
24
23
  )
25
24
 
26
25
  from doctr import models
@@ -33,7 +32,7 @@ __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config
33
32
 
34
33
 
35
34
  AVAILABLE_ARCHS = {
36
- "classification": models.classification.zoo.ARCHS,
35
+ "classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
37
36
  "detection": models.detection.zoo.ARCHS,
38
37
  "recognition": models.recognition.zoo.ARCHS,
39
38
  }
@@ -62,7 +61,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
62
61
  """Save model and config to disk for pushing to huggingface hub
63
62
 
64
63
  Args:
65
- ----
66
64
  model: TF or PyTorch model to be saved
67
65
  save_dir: directory to save model and config
68
66
  arch: architecture name
@@ -74,7 +72,7 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
74
72
  weights_path = save_directory / "pytorch_model.bin"
75
73
  torch.save(model.state_dict(), weights_path)
76
74
  elif is_tf_available():
77
- weights_path = save_directory / "tf_model" / "weights"
75
+ weights_path = save_directory / "tf_model.weights.h5"
78
76
  model.save_weights(str(weights_path))
79
77
 
80
78
  config_path = save_directory / "config.json"
@@ -98,7 +96,6 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
98
96
  >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
99
97
 
100
98
  Args:
101
- ----
102
99
  model: TF or PyTorch model to be saved
103
100
  model_name: name of the model which is also the repository name
104
101
  task: task name
@@ -115,9 +112,9 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
115
112
  # default readme
116
113
  readme = textwrap.dedent(
117
114
  f"""
118
- ---
115
+
119
116
  language: en
120
- ---
117
+
121
118
 
122
119
  <p align="center">
123
120
  <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
@@ -174,7 +171,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
174
171
 
175
172
  local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
176
173
  repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
177
- repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True)
174
+ repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)
178
175
 
179
176
  with repo.commit(commit_message):
180
177
  _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
@@ -191,12 +188,10 @@ def from_hub(repo_id: str, **kwargs: Any):
191
188
  >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")
192
189
 
193
190
  Args:
194
- ----
195
191
  repo_id: HuggingFace model hub repo
196
192
  kwargs: kwargs of `hf_hub_download` or `snapshot_download`
197
193
 
198
194
  Returns:
199
- -------
200
195
  Model loaded with the checkpoint
201
196
  """
202
197
  # Get the config
@@ -225,7 +220,7 @@ def from_hub(repo_id: str, **kwargs: Any):
225
220
  state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
226
221
  model.load_state_dict(state_dict)
227
222
  else: # tf
228
- repo_path = snapshot_download(repo_id, **kwargs)
229
- model.load_weights(os.path.join(repo_path, "tf_model", "weights"))
223
+ weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
224
+ model.load_weights(weights)
230
225
 
231
226
  return model