python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
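
A pattern that recurs throughout the hunks below (which cover the differentiable_binarization and FAST detection modules) is a typing modernization: `Dict`/`List`/`Tuple`/`Optional`/`Union` imports from `typing` are replaced with built-in generics and `|` unions (PEP 585/604), and `Callable` now comes from `collections.abc`. A minimal before/after sketch of the pattern, using an illustrative function that is not part of the package:

    # 0.10.0 style
    from typing import Callable, Dict, List, Optional

    def build(
        cfg: Optional[Dict[str, float]] = None,
        layers: List[str] = [],
        fn: Callable[[int], int] = abs,
    ) -> Dict[str, float]:
        return cfg or {}

    # 0.12.0 style; `X | None` in annotations needs Python >= 3.10
    # (or `from __future__ import annotations`)
    from collections.abc import Callable

    def build(
        cfg: dict[str, float] | None = None,
        layers: list[str] = [],
        fn: Callable[[int], int] = abs,
    ) -> dict[str, float]:
        return cfg or {}
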
doctr/models/detection/differentiable_binarization/pytorch.py

@@ -1,9 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, Callable, Dict, List, Optional
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 import torch
@@ -22,7 +23,7 @@ from .base import DBPostProcessor, _DBNet
 __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "db_resnet50": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
@@ -47,7 +48,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 class FeaturePyramidNetwork(nn.Module):
     def __init__(
         self,
-        in_channels: List[int],
+        in_channels: list[int],
         out_channels: int,
         deform_conv: bool = False,
     ) -> None:
@@ -76,12 +77,12 @@ class FeaturePyramidNetwork(nn.Module):
             for idx, chans in enumerate(in_channels)
         ])
 
-    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+    def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
         if len(x) != len(self.out_branches):
             raise AssertionError
         # Conv1x1 to get the same number of channels
-        _x: List[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
-        out: List[torch.Tensor] = [_x[-1]]
+        _x: list[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
+        out: list[torch.Tensor] = [_x[-1]]
         for t in _x[:-1][::-1]:
             out.append(self.upsample(out[-1]) + t)
 
@@ -96,7 +97,6 @@ class DBNet(_DBNet, nn.Module):
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         head_chans: the number of channels in the head
         deform_conv: whether to use deformable convolution
@@ -117,8 +117,8 @@ class DBNet(_DBNet, nn.Module):
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
-        class_names: List[str] = [CLASS_NAME],
+        cfg: dict[str, Any] | None = None,
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -179,13 +179,22 @@ class DBNet(_DBNet, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
-        target: Optional[List[np.ndarray]] = None,
+        target: list[np.ndarray] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, torch.Tensor]:
+    ) -> dict[str, torch.Tensor]:
         # Extract feature maps at different stages
         feats = self.feat_extractor(x)
         feats = [feats[str(idx)] for idx in range(len(feats))]
@@ -193,7 +202,7 @@ class DBNet(_DBNet, nn.Module):
         feat_concat = self.fpn(feats)
         logits = self.prob_head(feat_concat)
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -205,11 +214,16 @@ class DBNet(_DBNet, nn.Module):
         out["out_map"] = prob_map
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
+                return [
+                    dict(zip(self.class_names, preds))
+                    for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                ]
+
             # Post-process boxes (keep only text predictions)
-            out["preds"] = [
-                dict(zip(self.class_names, preds))
-                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
-            ]
+            out["preds"] = _postprocess(prob_map)
 
         if target is not None:
             thresh_map = self.thresh_head(feat_concat)
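
Reviewer note on the hunk above: post-processing leaves the tensor world (`.detach().cpu().numpy()` feeding an OpenCV-based postprocessor), which causes graph breaks under `torch.compile`. Moving it into a closure decorated with `torch.compiler.disable` keeps the compiled forward pass intact while the boxes are still computed eagerly. A self-contained sketch of the pattern under that assumption (the module and names are illustrative, not doctr's):

    import torch

    class TinyDetector(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.head = torch.nn.Conv2d(3, 1, 1)

        def forward(self, x: torch.Tensor) -> dict:
            prob_map = torch.sigmoid(self.head(x))

            @torch.compiler.disable  # keep the numpy round-trip out of the compiled graph
            def _postprocess(p: torch.Tensor) -> list:
                return [(m > 0.5).sum().item() for m in p.detach().cpu().numpy()]

            return {"out_map": prob_map, "preds": _postprocess(prob_map)}

    model = torch.compile(TinyDetector())
    out = model(torch.rand(2, 3, 32, 32))  # forward is compiled; post-processing runs eagerly
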
@@ -222,7 +236,7 @@ class DBNet(_DBNet, nn.Module):
         self,
         out_map: torch.Tensor,
         thresh_map: torch.Tensor,
-        target: List[np.ndarray],
+        target: list[np.ndarray],
         gamma: float = 2.0,
         alpha: float = 0.5,
         eps: float = 1e-8,
@@ -231,7 +245,6 @@ class DBNet(_DBNet, nn.Module):
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
-        ----
             out_map: output feature map of the model of shape (N, C, H, W)
             thresh_map: threshold map of shape (N, C, H, W)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
@@ -240,7 +253,6 @@ class DBNet(_DBNet, nn.Module):
             eps: epsilon factor in dice loss
 
         Returns:
-        -------
             A loss tensor
         """
         if gamma < 0:
@@ -273,7 +285,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
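
For context on the `dice_map` line in the last hunk: it implements the approximate (differentiable) binarization from the DBNet paper cited in the docstring, a steep sigmoid with amplification factor k:

    \hat{B}_{i,j} = \frac{1}{1 + e^{-k\,(P_{i,j} - T_{i,j})}}, \qquad k = 50

where P is the probability map and T the learned threshold map. The only 0.12.0 change here is the added `# type: ignore[assignment]`, presumably because the other branch binds `dice_map` to the output of `torch.softmax` and mypy infers a conflicting type.
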
@@ -290,10 +302,10 @@ def _dbnet(
     arch: str,
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
-    fpn_layers: List[str],
-    backbone_submodule: Optional[str] = None,
+    fpn_layers: list[str],
+    backbone_submodule: str | None = None,
     pretrained_backbone: bool = True,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -325,7 +337,7 @@ def _dbnet(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -341,12 +353,10 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
@@ -376,12 +386,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
@@ -411,12 +419,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
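
With the `from_pretrained` method added above, loading a checkpoint no longer requires calling `load_pretrained_params` by hand; the builder functions now route through it as well. A usage sketch (the checkpoint path is a placeholder):

    from doctr.models import db_resnet50

    model = db_resnet50(pretrained=False)
    # accepts a local path or a URL; extra kwargs are forwarded to
    # doctr.models.utils.load_pretrained_params
    model.from_pretrained("path/to/db_resnet50_checkpoint.pt")
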
doctr/models/detection/differentiable_binarization/tensorflow.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -29,7 +29,7 @@ from .base import DBPostProcessor, _DBNet
 __all__ = ["DBNet", "db_resnet50", "db_mobilenet_v3_large"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "db_resnet50": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
@@ -50,7 +50,6 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
     <https://arxiv.org/pdf/1612.03144.pdf>`_.
 
     Args:
-    ----
         channels: number of channel to output
     """
 
@@ -72,12 +71,10 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         """Module which performs a 3x3 convolution followed by up-sampling
 
         Args:
-        ----
             channels: number of output channels
             dilation_factor (int): dilation factor to scale the convolution output before concatenation
 
         Returns:
-        -------
             a keras.layers.Layer object, wrapping these operations in a sequential module
 
         """
@@ -95,7 +92,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
 
     def call(
         self,
-        x: List[tf.Tensor],
+        x: list[tf.Tensor],
         **kwargs: Any,
     ) -> tf.Tensor:
         # Channel mapping
@@ -114,7 +111,6 @@ class DBNet(_DBNet, Model, NestedObject):
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
         bin_thresh: threshold for binarization
@@ -125,7 +121,7 @@ class DBNet(_DBNet, Model, NestedObject):
         class_names: list of class names
     """
 
-    _children_names: List[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
+    _children_names: list[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
 
     def __init__(
         self,
@@ -135,8 +131,8 @@ class DBNet(_DBNet, Model, NestedObject):
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
-        class_names: List[str] = [CLASS_NAME],
+        cfg: dict[str, Any] | None = None,
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -171,11 +167,20 @@ class DBNet(_DBNet, Model, NestedObject):
             assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
         )
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         out_map: tf.Tensor,
         thresh_map: tf.Tensor,
-        target: List[Dict[str, np.ndarray]],
+        target: list[dict[str, np.ndarray]],
         gamma: float = 2.0,
         alpha: float = 0.5,
         eps: float = 1e-8,
@@ -184,7 +189,6 @@ class DBNet(_DBNet, Model, NestedObject):
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
-        ----
             out_map: output feature map of the model of shape (N, H, W, C)
             thresh_map: threshold map of shape (N, H, W, C)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
@@ -193,7 +197,6 @@ class DBNet(_DBNet, Model, NestedObject):
             eps: epsilon factor in dice loss
 
         Returns:
-        -------
             A loss tensor
         """
         if gamma < 0:
@@ -246,16 +249,16 @@ class DBNet(_DBNet, Model, NestedObject):
     def call(
         self,
         x: tf.Tensor,
-        target: Optional[List[Dict[str, np.ndarray]]] = None,
+        target: list[dict[str, np.ndarray]] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         feat_maps = self.feat_extractor(x, **kwargs)
         feat_concat = self.fpn(feat_maps, **kwargs)
         logits = self.probability_head(feat_concat, **kwargs)
 
-        out: Dict[str, tf.Tensor] = {}
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -282,9 +285,9 @@ def _db_resnet(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    fpn_layers: List[str],
+    fpn_layers: list[str],
     pretrained_backbone: bool = True,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -315,8 +318,7 @@ def _db_resnet(
     # Load pretrained parameters
     if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model,
+        model.from_pretrained(
             _cfg["url"],
             skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
         )
@@ -328,9 +330,9 @@ def _db_mobilenet(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    fpn_layers: List[str],
+    fpn_layers: list[str],
     pretrained_backbone: bool = True,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -359,8 +361,7 @@ def _db_mobilenet(
     # Load pretrained parameters
     if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model,
+        model.from_pretrained(
             _cfg["url"],
             skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
         )
@@ -379,12 +380,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _db_resnet(
@@ -407,12 +406,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _db_mobilenet(
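
As in the PyTorch file, both TensorFlow builder helpers now route through `model.from_pretrained` with `skip_mismatch`, so a model instantiated with custom `class_names` can still reuse pretrained weights while layers whose shapes no longer match are skipped. A hedged sketch (the class names are illustrative):

    from doctr.models import db_resnet50

    # the output head differs from the pretrained checkpoint,
    # so its mismatching weights are skipped for fine-tuning
    model = db_resnet50(pretrained=True, class_names=["words", "stamps"])
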
doctr/models/detection/fast/__init__.py

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
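
Note the inversion in this `__init__.py`: 0.10.0 preferred the TensorFlow implementation when both frameworks were importable, while 0.12.0 prefers PyTorch. The same guards can be used to check which backend doctr resolves to:

    from doctr.file_utils import is_tf_available, is_torch_available

    if is_torch_available():
        print("doctr uses the PyTorch implementations")  # wins when both are installed
    elif is_tf_available():
        print("doctr falls back to the TensorFlow implementations")
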
doctr/models/detection/fast/base.py

@@ -1,11 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -23,7 +22,6 @@ class FASTPostProcessor(DetectionPostProcessor):
     """Implements a post processor for FAST model.
 
     Args:
-    ----
         bin_thresh: threshold used to binzarized p_map at inference time
         box_thresh: minimal objectness score to consider a box
         assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -45,11 +43,9 @@ class FASTPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
-        ----
             points: The first parameter.
 
         Returns:
-        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -60,9 +56,8 @@ class FASTPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-            poly = Polygon(points)
-            area = poly.area
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
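
The hunk above swaps shapely's `Polygon.area`/`Polygon.length` for the OpenCV equivalents `cv2.contourArea` and `cv2.arcLength`, avoiding a shapely round-trip in the post-processing hot path. The offset distance d = area * unclip_ratio / length then drives the pyclipper expansion, as in DB-style post-processing. A self-contained sketch (the polygon and ratio are illustrative):

    import cv2
    import numpy as np
    import pyclipper

    points = np.array([[10, 10], [100, 10], [100, 40], [10, 40]], dtype=np.int32)
    unclip_ratio = 1.5  # illustrative value

    area = cv2.contourArea(points)
    length = cv2.arcLength(points, closed=True)
    distance = area * unclip_ratio / length  # same formula as the hunk above

    offset = pyclipper.PyclipperOffset()
    offset.AddPath(points.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    expanded = np.array(offset.Execute(distance)[0])  # vertices of the expanded polygon
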
@@ -94,24 +89,22 @@ class FASTPostProcessor(DetectionPostProcessor):
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
-        ----
             pred: Pred map from differentiable linknet output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
-        -------
             np tensor boxes for the bitmap, each box is a 6-element list
             containing x, y, w, h, alpha, score for the box
         """
         height, width = bitmap.shape[:2]
-        boxes: List[Union[np.ndarray, List[float]]] = []
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -158,20 +151,18 @@ class _FAST(BaseModel):
 
     def build_target(
         self,
-        target: List[Dict[str, np.ndarray]],
-        output_shape: Tuple[int, int, int],
+        target: list[dict[str, np.ndarray]],
+        output_shape: tuple[int, int, int],
         channels_last: bool = True,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
-        ----
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
             channels_last: whether channels are last or not
 
         Returns:
-        -------
             the new formatted target, mask and shrunken text kernel
         """
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):