python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/detection/differentiable_binarization/pytorch.py
@@ -1,9 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, Callable, Dict, List, Optional
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 import torch
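
A recurring change across this release: annotations move from the legacy typing aliases (Dict, List, Optional, Tuple, Union) to builtin generics (PEP 585) and X | Y unions (PEP 604), with Callable imported from collections.abc. A minimal before/after sketch of the style (illustrative function, not doctr code):

    from collections.abc import Callable
    from typing import Any

    # Before: def configure(cfg: Optional[Dict[str, Any]], hooks: List[Callable[[], None]]) -> Tuple[int, ...]: ...
    # After, runtime-valid on Python >= 3.10 (PEP 585 + PEP 604):
    def configure(cfg: dict[str, Any] | None, hooks: list[Callable[[], None]]) -> tuple[int, ...]:
        for hook in hooks:
            hook()
        return tuple(range(len(hooks)))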
@@ -22,7 +23,7 @@ from .base import DBPostProcessor, _DBNet
 __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "db_resnet50": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
@@ -47,7 +48,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 class FeaturePyramidNetwork(nn.Module):
     def __init__(
         self,
-        in_channels: List[int],
+        in_channels: list[int],
         out_channels: int,
         deform_conv: bool = False,
     ) -> None:
@@ -76,12 +77,12 @@ class FeaturePyramidNetwork(nn.Module):
             for idx, chans in enumerate(in_channels)
         ])
 
-    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+    def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
         if len(x) != len(self.out_branches):
             raise AssertionError
         # Conv1x1 to get the same number of channels
-        _x: List[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
-        out: List[torch.Tensor] = [_x[-1]]
+        _x: list[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
+        out: list[torch.Tensor] = [_x[-1]]
         for t in _x[:-1][::-1]:
             out.append(self.upsample(out[-1]) + t)
 
@@ -96,7 +97,6 @@ class DBNet(_DBNet, nn.Module):
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
        feature extractor: the backbone serving as feature extractor
        head_chans: the number of channels in the head
        deform_conv: whether to use deformable convolution
@@ -117,8 +117,8 @@ class DBNet(_DBNet, nn.Module):
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
-        class_names: List[str] = [CLASS_NAME],
+        cfg: dict[str, Any] | None = None,
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -182,10 +182,10 @@ class DBNet(_DBNet, nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        target: Optional[List[np.ndarray]] = None,
+        target: list[np.ndarray] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, torch.Tensor]:
+    ) -> dict[str, torch.Tensor]:
         # Extract feature maps at different stages
         feats = self.feat_extractor(x)
         feats = [feats[str(idx)] for idx in range(len(feats))]
@@ -193,7 +193,7 @@ class DBNet(_DBNet, nn.Module):
         feat_concat = self.fpn(feats)
         logits = self.prob_head(feat_concat)
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
             return out
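
When exportable is set, forward returns the raw logits and skips the NumPy post-processing, leaving a graph that tracing-based exporters can consume. A hedged sketch of how that flag is typically used for ONNX export (file name and shapes are illustrative; doctr also ships its own export helpers):

    import torch
    from doctr.models import db_resnet50

    model = db_resnet50(pretrained=True, exportable=True).eval()
    dummy_input = torch.rand(1, 3, 1024, 1024)  # (N, C, H, W), matching the default input_shape
    torch.onnx.export(model, dummy_input, "db_resnet50.onnx", input_names=["input"], output_names=["logits"])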
@@ -205,11 +205,16 @@ class DBNet(_DBNet, nn.Module):
         out["out_map"] = prob_map
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
+                return [
+                    dict(zip(self.class_names, preds))
+                    for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                ]
+
             # Post-process boxes (keep only text predictions)
-            out["preds"] = [
-                dict(zip(self.class_names, preds))
-                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
-            ]
+            out["preds"] = _postprocess(prob_map)
 
         if target is not None:
             thresh_map = self.thresh_head(feat_concat)
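
Moving the NumPy-based post-processing into a nested function decorated with torch.compiler.disable keeps it out of the region captured by torch.compile: the decorated function always runs eagerly, so the .cpu().numpy() round-trip no longer forces graph breaks inside the compiled forward. A standalone sketch of the pattern (not doctr code):

    import torch

    @torch.compile
    def forward(x: torch.Tensor) -> torch.Tensor:
        y = torch.relu(x) + 1  # captured by torch.compile

        @torch.compiler.disable
        def postprocess(t: torch.Tensor) -> torch.Tensor:
            # Always eager: safe to leave the tensor world (NumPy, OpenCV, ...)
            return torch.from_numpy(t.detach().cpu().numpy() * 2)

        return postprocess(y)

    print(forward(torch.ones(3)))  # tensor([4., 4., 4.])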
@@ -222,7 +227,7 @@ class DBNet(_DBNet, nn.Module):
         self,
         out_map: torch.Tensor,
         thresh_map: torch.Tensor,
-        target: List[np.ndarray],
+        target: list[np.ndarray],
         gamma: float = 2.0,
         alpha: float = 0.5,
         eps: float = 1e-8,
@@ -231,7 +236,6 @@ class DBNet(_DBNet, nn.Module):
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
-        ----
            out_map: output feature map of the model of shape (N, C, H, W)
            thresh_map: threshold map of shape (N, C, H, W)
            target: list of dictionary where each dict has a `boxes` and a `flags` entry
@@ -240,7 +244,6 @@ class DBNet(_DBNet, nn.Module):
            eps: epsilon factor in dice loss
 
         Returns:
-        -------
            A loss tensor
         """
         if gamma < 0:
@@ -273,7 +276,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
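
For context, this steep sigmoid is the differentiable binarization from the DB paper: the approximate binary map is B = 1 / (1 + exp(-k * (P - T))) with amplification factor k = 50, where P is the probability map and T the learned threshold map. Large k makes the output nearly a step function of P - T while keeping it differentiable for the dice loss. A small numeric illustration:

    import torch

    k = 50.0
    P = torch.tensor([0.30, 0.49, 0.51, 0.70])  # probability map samples
    T = torch.full_like(P, 0.50)                # threshold map samples
    B = 1 / (1 + torch.exp(-k * (P - T)))
    print(B)  # ~[0.0000, 0.3775, 0.6225, 1.0000]: sharp but smooth transition around P == T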
@@ -290,10 +293,10 @@ def _dbnet(
     arch: str,
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
-    fpn_layers: List[str],
-    backbone_submodule: Optional[str] = None,
+    fpn_layers: list[str],
+    backbone_submodule: str | None = None,
     pretrained_backbone: bool = True,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -341,12 +344,10 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
        text detection architecture
     """
     return _dbnet(
@@ -376,12 +377,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
        text detection architecture
     """
     return _dbnet(
@@ -411,12 +410,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
        text detection architecture
     """
     return _dbnet(
doctr/models/detection/differentiable_binarization/tensorflow.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,16 +6,21 @@
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
+from tensorflow.keras import Model, Sequential, layers, losses
 from tensorflow.keras.applications import ResNet50
 
 from doctr.file_utils import CLASS_NAME
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
+from doctr.models.utils import (
+    IntermediateLayerGetter,
+    _bf16_to_float32,
+    _build_model,
+    conv_sequence,
+    load_pretrained_params,
+)
 from doctr.utils.repr import NestedObject
 
 from ...classification import mobilenet_v3_large
@@ -24,18 +29,18 @@ from .base import DBPostProcessor, _DBNet
 __all__ = ["DBNet", "db_resnet50", "db_mobilenet_v3_large"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "db_resnet50": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0",
     },
     "db_mobilenet_v3_large": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-da524564.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0",
     },
 }
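
Note the checkpoint URLs: pretrained TensorFlow weights move from zipped archives to plain Keras .weights.h5 files, the format Keras 3 expects (its save_weights/load_weights enforce that suffix). This is also why _build_model is called before loading further down: weights can only be restored into a model whose variables already exist. A hedged sketch of the mechanics (toy model and file name are illustrative):

    from tensorflow.keras import Sequential, layers

    model = Sequential([layers.Input(shape=(64, 64, 3)), layers.Conv2D(8, 3)])
    model.save_weights("toy.weights.h5")  # Keras 3 requires the .weights.h5 suffix

    fresh = Sequential([layers.Input(shape=(64, 64, 3)), layers.Conv2D(8, 3)])
    fresh.load_weights("toy.weights.h5")  # variables must already be built to receive weights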
@@ -45,7 +50,6 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
     <https://arxiv.org/pdf/1612.03144.pdf>`_.
 
     Args:
-    ----
        channels: number of channel to output
     """
 
@@ -67,12 +71,10 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         """Module which performs a 3x3 convolution followed by up-sampling
 
         Args:
-        ----
            channels: number of output channels
            dilation_factor (int): dilation factor to scale the convolution output before concatenation
 
         Returns:
-        -------
            a keras.layers.Layer object, wrapping these operations in a sequential module
 
         """
@@ -81,7 +83,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         if dilation_factor > 1:
             _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest"))
 
-        module = keras.Sequential(_layers)
+        module = Sequential(_layers)
 
         return module
 
@@ -90,7 +92,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
 
     def call(
         self,
-        x: List[tf.Tensor],
+        x: list[tf.Tensor],
         **kwargs: Any,
     ) -> tf.Tensor:
         # Channel mapping
@@ -104,12 +106,11 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         return layers.concatenate(results)
 
 
-class DBNet(_DBNet, keras.Model, NestedObject):
+class DBNet(_DBNet, Model, NestedObject):
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
        feature extractor: the backbone serving as feature extractor
        fpn_channels: number of channels each extracted feature maps is mapped to
        bin_thresh: threshold for binarization
@@ -120,7 +121,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
        class_names: list of class names
     """
 
-    _children_names: List[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
+    _children_names: list[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
 
     def __init__(
         self,
@@ -130,8 +131,8 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
-        class_names: List[str] = [CLASS_NAME],
+        cfg: dict[str, Any] | None = None,
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -147,14 +148,14 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
         output_shape = tuple(self.fpn(_inputs).shape)
 
-        self.probability_head = keras.Sequential([
+        self.probability_head = Sequential([
             *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
             layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
             layers.BatchNormalization(),
             layers.Activation("relu"),
             layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
         ])
-        self.threshold_head = keras.Sequential([
+        self.threshold_head = Sequential([
             *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
             layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
             layers.BatchNormalization(),
@@ -170,7 +171,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         self,
         out_map: tf.Tensor,
         thresh_map: tf.Tensor,
-        target: List[Dict[str, np.ndarray]],
+        target: list[dict[str, np.ndarray]],
         gamma: float = 2.0,
         alpha: float = 0.5,
         eps: float = 1e-8,
@@ -179,7 +180,6 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
-        ----
            out_map: output feature map of the model of shape (N, H, W, C)
            thresh_map: threshold map of shape (N, H, W, C)
            target: list of dictionary where each dict has a `boxes` and a `flags` entry
@@ -188,7 +188,6 @@ class DBNet(_DBNet, keras.Model, NestedObject):
            eps: epsilon factor in dice loss
 
         Returns:
-        -------
            A loss tensor
         """
         if gamma < 0:
@@ -206,7 +205,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
 
         # Focal loss
         focal_scale = 10.0
-        bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
+        bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
 
         # Convert logits to prob, compute gamma factor
         p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
@@ -241,16 +240,16 @@ class DBNet(_DBNet, keras.Model, NestedObject):
     def call(
         self,
         x: tf.Tensor,
-        target: Optional[List[Dict[str, np.ndarray]]] = None,
+        target: list[dict[str, np.ndarray]] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         feat_maps = self.feat_extractor(x, **kwargs)
         feat_concat = self.fpn(feat_maps, **kwargs)
         logits = self.probability_head(feat_concat, **kwargs)
 
-        out: Dict[str, tf.Tensor] = {}
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -277,9 +276,9 @@ def _db_resnet(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    fpn_layers: List[str],
+    fpn_layers: list[str],
     pretrained_backbone: bool = True,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -305,9 +304,16 @@ def _db_resnet(
 
     # Build the model
     model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, _cfg["url"])
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
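
The skip_mismatch flag makes loading tolerant when the detection head no longer matches the checkpoint, which is exactly what happens when a caller passes custom class_names: backbone and FPN weights are restored, while the mismatching head layers stay randomly initialized for fine-tuning. A hedged usage sketch (the class names are invented for illustration):

    from doctr.models import db_resnet50

    # Same classes as the checkpoint: all layers load strictly.
    model = db_resnet50(pretrained=True)

    # Custom classes: head output channels differ from the checkpoint,
    # so mismatching layers are skipped and left trainable from scratch.
    finetune_model = db_resnet50(pretrained=True, class_names=["caption", "stamp"])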
@@ -316,9 +322,9 @@ def _db_mobilenet(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    fpn_layers: List[str],
+    fpn_layers: list[str],
     pretrained_backbone: bool = True,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -326,6 +332,10 @@ def _db_mobilenet(
     # Patch the config
     _cfg = deepcopy(default_cfgs[arch])
     _cfg["input_shape"] = input_shape or _cfg["input_shape"]
+    if not kwargs.get("class_names", None):
+        kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
+    else:
+        kwargs["class_names"] = sorted(kwargs["class_names"])
 
     # Feature extractor
     feat_extractor = IntermediateLayerGetter(
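
Sorting user-supplied class_names pins each class to a deterministic output channel, independent of the order the caller used, and it is this canonical order that the skip_mismatch comparison relies on. The normalization in isolation (CLASS_NAME is doctr's default detection class, "words"):

    CLASS_NAME = "words"

    def normalize_class_names(class_names: list[str] | None) -> list[str]:
        # Mirrors the patch above: default when empty, otherwise a stable sorted order.
        return sorted(class_names) if class_names else [CLASS_NAME]

    print(normalize_class_names(None))                  # ['words']
    print(normalize_class_names(["stamp", "caption"]))  # ['caption', 'stamp']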
@@ -339,9 +349,15 @@ def _db_mobilenet(
 
     # Build the model
     model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, _cfg["url"])
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
 
@@ -357,12 +373,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
        text detection architecture
     """
     return _db_resnet(
@@ -385,12 +399,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
        text detection architecture
     """
     return _db_mobilenet(
doctr/models/detection/fast/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
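
This swap, applied across the backend-dispatching __init__.py files in this release, inverts the framework priority: with both frameworks installed, the PyTorch implementation is now imported ahead of the TensorFlow one. The dispatch pattern in isolation, assuming only the two availability helpers from doctr.file_utils:

    from doctr.file_utils import is_tf_available, is_torch_available

    if is_torch_available():
        backend = "pytorch"      # preferred as of this release when both are installed
    elif is_tf_available():
        backend = "tensorflow"
    else:
        raise ImportError("doctr requires either PyTorch or TensorFlow")
    print(backend)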
doctr/models/detection/fast/base.py
@@ -1,11 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -23,7 +22,6 @@ class FASTPostProcessor(DetectionPostProcessor):
     """Implements a post processor for FAST model.
 
     Args:
-    ----
        bin_thresh: threshold used to binzarized p_map at inference time
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -45,11 +43,9 @@ class FASTPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
-        ----
            points: The first parameter.
 
         Returns:
-        -------
            a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -94,24 +90,22 @@ class FASTPostProcessor(DetectionPostProcessor):
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
-        ----
            pred: Pred map from differentiable linknet output
            bitmap: Bitmap map computed from pred (binarized)
            angle_tol: Comparison tolerance of the angle with the median angle across the page
            ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
-        -------
            np tensor boxes for the bitmap, each box is a 6-element list
            containing x, y, w, h, alpha, score for the box
         """
         height, width = bitmap.shape[:2]
-        boxes: List[Union[np.ndarray, List[float]]] = []
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
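
The surrounding method converts the binarized map into boxes via connected components. A self-contained sketch of that core step (the toy bitmap is illustrative; doctr additionally filters candidates by size and objectness score):

    import cv2
    import numpy as np

    bitmap = np.zeros((32, 32), dtype=np.uint8)  # toy binary map with one "text" blob
    bitmap[10:14, 5:25] = 1

    contours, _ = cv2.findContours(bitmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)                 # straight-page box
        (cx, cy), (rw, rh), alpha = cv2.minAreaRect(contour)   # rotated box with angle
        print((x, y, w, h), round(alpha, 1))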
@@ -158,20 +152,18 @@ class _FAST(BaseModel):
 
     def build_target(
         self,
-        target: List[Dict[str, np.ndarray]],
-        output_shape: Tuple[int, int, int],
+        target: list[dict[str, np.ndarray]],
+        output_shape: tuple[int, int, int],
         channels_last: bool = True,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
-        ----
            target: target coming from dataset
            output_shape: shape of the output of the model without batch_size
            channels_last: whether channels are last or not
 
         Returns:
-        -------
            the new formatted target, mask and shrunken text kernel
         """
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):