python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/classification/vit/tensorflow.py CHANGED
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Sequential, layers
@@ -19,7 +19,7 @@ from ...utils import _build_model, load_pretrained_params
 __all__ = ["vit_s", "vit_b"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "vit_s": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -41,7 +41,6 @@ class ClassifierHead(layers.Layer, NestedObject):
     """Classifier head for Vision Transformer
 
     Args:
-    ----
         num_classes: number of output classes
     """
 
@@ -61,7 +60,6 @@ class VisionTransformer(Sequential):
     <https://arxiv.org/pdf/2010.11929.pdf>`_.
 
     Args:
-    ----
         d_model: dimension of the transformer layers
         num_layers: number of transformer layers
         num_heads: number of attention heads
@@ -79,12 +77,12 @@ class VisionTransformer(Sequential):
         num_layers: int,
         num_heads: int,
         ffd_ratio: int,
-        patch_size: Tuple[int, int] = (4, 4),
-        input_shape: Tuple[int, int, int] = (32, 32, 3),
+        patch_size: tuple[int, int] = (4, 4),
+        input_shape: tuple[int, int, int] = (32, 32, 3),
         dropout: float = 0.0,
         num_classes: int = 1000,
         include_top: bool = True,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         _layers = [
             PatchEmbedding(input_shape, d_model, patch_size),
@@ -103,6 +101,15 @@ class VisionTransformer(Sequential):
         super().__init__(_layers)
         self.cfg = cfg
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
 
 def _vit(
     arch: str,
@@ -148,12 +155,10 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
@@ -179,12 +184,10 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
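The main functional addition in this file is the new `from_pretrained` method, which loads a checkpoint onto an already-built model instead of relying solely on the `pretrained` flag. A minimal usage sketch, assuming a locally available checkpoint (the file name below is hypothetical):

from doctr.models import classification

# Build the architecture without the default pretrained weights
model = classification.vit_s(pretrained=False)
# Path/URL resolution is delegated to doctr.models.utils.load_pretrained_params
model.from_pretrained("path/to/vit_s_checkpoint.weights.h5")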
doctr/models/classification/zoo.py CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List
+from typing import Any
 
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 
 from .. import classification
 from ..preprocessor import PreProcessor
@@ -13,7 +13,7 @@ from .predictor import OrientationPredictor
 
 __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
 
-ARCHS: List[str] = [
+ARCHS: list[str] = [
     "magc_resnet31",
     "mobilenet_v3_small",
     "mobilenet_v3_small_r",
@@ -31,7 +31,11 @@ ARCHS: List[str] = [
     "vit_s",
     "vit_b",
 ]
-ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
+
+if is_torch_available():
+    ARCHS.extend(["vip_tiny", "vip_base"])
+
+ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
 def _orientation_predictor(
@@ -48,7 +52,14 @@ def _orientation_predictor(
         # Load directly classifier from backbone
         _model = classification.__dict__[arch](pretrained=pretrained)
     else:
-        if not isinstance(arch, classification.MobileNetV3):
+        allowed_archs = [classification.MobileNetV3]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch
 
@@ -63,7 +74,7 @@ def _orientation_predictor(
 
 
 def crop_orientation_predictor(
-    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
+    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, batch_size: int = 128, **kwargs: Any
 ) -> OrientationPredictor:
     """Crop orientation classification architecture.
 
@@ -74,20 +85,19 @@ def crop_orientation_predictor(
     >>> out = model([input_crop])
 
     Args:
-    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        batch_size: number of samples the model processes in parallel
         **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
-    -------
         OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)
+    return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="crop", **kwargs)
 
 
 def page_orientation_predictor(
-    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
+    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, batch_size: int = 4, **kwargs: Any
 ) -> OrientationPredictor:
     """Page orientation classification architecture.
 
@@ -98,13 +108,12 @@ def page_orientation_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        batch_size: number of samples the model processes in parallel
        **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
-    -------
         OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
+    return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="page", **kwargs)
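Two behavioral changes stand out here: `batch_size` is now an explicit argument of both factories (defaulting to 128 for crops, 4 for pages), and torch-compiled models are accepted as custom architectures via `_CompiledModule`. A short sketch of the updated calls, making only the defaults visible in the signatures above explicit:

from doctr.models import crop_orientation_predictor, page_orientation_predictor

# Defaults made explicit: crops are classified 128 at a time, pages 4 at a time
crop_classifier = crop_orientation_predictor(pretrained=True, batch_size=128)
page_classifier = page_orientation_predictor(pretrained=True, batch_size=4)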
doctr/models/core.py CHANGED
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
-from typing import Any, Dict, Optional
+from typing import Any
 
 from doctr.utils.repr import NestedObject
 
@@ -14,6 +14,6 @@ __all__ = ["BaseModel"]
 class BaseModel(NestedObject):
     """Implements abstract DetectionModel class"""
 
-    def __init__(self, cfg: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, cfg: dict[str, Any] | None = None) -> None:
         super().__init__()
         self.cfg = cfg
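The `Optional[Dict[...]]` to `dict[...] | None` rewrite seen here repeats across nearly every file in this release: typing.Dict/List/Tuple/Optional/Union give way to PEP 585 built-in generics and PEP 604 union syntax, which require Python 3.10+ when evaluated at runtime. The two spellings are equivalent:

from typing import Any, Dict, Optional

# pre-0.12 annotation style
cfg_old: Optional[Dict[str, Any]] = None
# 0.12 annotation style (built-in generics + union operator, Python 3.10+)
cfg_new: dict[str, Any] | None = None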
doctr/models/detection/_utils/__init__.py CHANGED
@@ -1,7 +1,7 @@
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from .base import *
 
-if is_tf_available():
-    from .tensorflow import *
-else:
+if is_torch_available():
     from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *
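Note the flipped dispatch order: with both backends installed, the PyTorch implementation is now imported in preference to TensorFlow (previously the reverse). The same guard functions are public, so downstream code can mirror the new priority; a minimal sketch:

from doctr.file_utils import is_tf_available, is_torch_available

# Mirrors the new import order: PyTorch wins when both backends are installed
if is_torch_available():
    backend = "pytorch"
elif is_tf_available():
    backend = "tensorflow"
else:
    raise ImportError("doctr requires either PyTorch or TensorFlow")
print(f"active doctr backend: {backend}")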
doctr/models/detection/_utils/base.py CHANGED
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Dict, List
 
 import numpy as np
 
@@ -11,16 +10,15 @@ __all__ = ["_remove_padding"]
 
 
 def _remove_padding(
-    pages: List[np.ndarray],
-    loc_preds: List[Dict[str, np.ndarray]],
+    pages: list[np.ndarray],
+    loc_preds: list[dict[str, np.ndarray]],
     preserve_aspect_ratio: bool,
     symmetric_pad: bool,
     assume_straight_pages: bool,
-) -> List[Dict[str, np.ndarray]]:
+) -> list[dict[str, np.ndarray]]:
     """Remove padding from the localization predictions
 
     Args:
-    ----
         pages: list of pages
         loc_preds: list of localization predictions
         preserve_aspect_ratio: whether the aspect ratio was preserved during padding
@@ -28,7 +26,6 @@ def _remove_padding(
         assume_straight_pages: whether the pages are assumed to be straight
 
     Returns:
-    -------
         list of unpaded localization predictions
     """
     if preserve_aspect_ratio:
doctr/models/detection/_utils/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,12 +13,10 @@ def erode(x: Tensor, kernel_size: int) -> Tensor:
     """Performs erosion on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, C, H, W)
         kernel_size: the size of the kernel to use for erosion
 
     Returns:
-    -------
         the eroded tensor
     """
     _pad = (kernel_size - 1) // 2
@@ -30,12 +28,10 @@ def dilate(x: Tensor, kernel_size: int) -> Tensor:
     """Performs dilation on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, C, H, W)
         kernel_size: the size of the kernel to use for dilation
 
     Returns:
-    -------
         the dilated tensor
     """
     _pad = (kernel_size - 1) // 2
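Both helpers implement binary morphology with pooling: dilation is a sliding-window maximum, and erosion follows by duality as 1 - max_pool(1 - x). The visible `_pad` context suggests bodies along these lines; the exact return statements are truncated above, so treat this as a sketch under that assumption:

import torch
from torch.nn.functional import max_pool2d

def erode(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    # Erosion by duality: max-pool the inverted map, then invert back
    _pad = (kernel_size - 1) // 2
    return 1 - max_pool2d(1 - x, kernel_size, stride=1, padding=_pad)

def dilate(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    # Dilation: a sliding-window maximum grows the foreground
    _pad = (kernel_size - 1) // 2
    return max_pool2d(x, kernel_size, stride=1, padding=_pad)

x = torch.zeros(1, 1, 5, 5)
x[0, 0, 2, 2] = 1.0             # a single foreground pixel
assert dilate(x, 3).sum() == 9  # grows into a 3x3 block
assert erode(x, 3).sum() == 0   # an isolated pixel is eroded away

The TensorFlow twin below does the same with tf.nn.max_pool2d and "SAME" padding on (N, H, W, C) tensors.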
doctr/models/detection/_utils/tensorflow.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,12 +12,10 @@ def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
     """Performs erosion on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, H, W, C)
         kernel_size: the size of the kernel to use for erosion
 
     Returns:
-    -------
         the eroded tensor
     """
     return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
@@ -27,12 +25,10 @@ def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
     """Performs dilation on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, H, W, C)
         kernel_size: the size of the kernel to use for dilation
 
     Returns:
-    -------
         the dilated tensor
     """
     return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
doctr/models/detection/core.py CHANGED
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List
 
 import cv2
 import numpy as np
@@ -17,7 +16,6 @@ class DetectionPostProcessor(NestedObject):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         box_thresh (float): minimal objectness score to consider a box
         bin_thresh (float): threshold to apply to segmentation raw heatmap
         assume straight_pages (bool): if True, fit straight boxes only
@@ -37,13 +35,11 @@ class DetectionPostProcessor(NestedObject):
         """Compute the confidence score for a polygon : mean of the p values on the polygon
 
         Args:
-        ----
             pred (np.ndarray): p map returned by the model
             points: coordinates of the polygon
             assume_straight_pages: if True, fit straight boxes only
 
         Returns:
-        -------
             polygon objectness
         """
         h, w = pred.shape[:2]
@@ -71,15 +67,13 @@ class DetectionPostProcessor(NestedObject):
     def __call__(
         self,
         proba_map,
-    ) -> List[List[np.ndarray]]:
+    ) -> list[list[np.ndarray]]:
         """Performs postprocessing for a list of model outputs
 
         Args:
-        ----
             proba_map: probability map of shape (N, H, W, C)
 
         Returns:
-        -------
             list of N class predictions (for each input sample), where each class predictions is a list of C tensors
             of shape (*, 5) or (*, 6)
         """
doctr/models/detection/differentiable_binarization/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/detection/differentiable_binarization/base.py CHANGED
@@ -1,11 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -22,7 +21,6 @@ class DBPostProcessor(DetectionPostProcessor):
     <https://github.com/xuannianz/DifferentiableBinarization>`_.
 
     Args:
-    ----
         unclip ratio: ratio used to unshrink polygons
         min_size_box: minimal length (pix) to keep a box
         max_candidates: maximum boxes to consider in a single page
@@ -47,11 +45,9 @@ class DBPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
-        ----
             points: The first parameter.
 
         Returns:
-        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -62,9 +58,8 @@ class DBPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-            poly = Polygon(points)
-            area = poly.area
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
@@ -96,25 +91,23 @@ class DBPostProcessor(DetectionPostProcessor):
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
-        ----
             pred: Pred map from differentiable binarization output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
-        -------
             np tensor boxes for the bitmap, each box is a 5-element list
             containing x, y, w, h, score for the box
         """
         height, width = bitmap.shape[:2]
         min_size_box = 2
-        boxes: List[Union[np.ndarray, List[float]]] = []
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):  # type: ignore[index]
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -164,7 +157,6 @@ class _DBNet:
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
     """
@@ -186,7 +178,6 @@ class _DBNet:
         """Compute the distance for each point of the map (xs, ys) to the (a, b) segment
 
         Args:
-        ----
             xs : map of x coordinates (height, width)
             ys : map of y coordinates (height, width)
             a: first point defining the [ab] segment
@@ -194,7 +185,6 @@ class _DBNet:
             eps: epsilon to avoid division by zero
 
         Returns:
-        -------
             The computed distance
 
         """
@@ -214,11 +204,10 @@ class _DBNet:
         polygon: np.ndarray,
         canvas: np.ndarray,
         mask: np.ndarray,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-        """Draw a polygon treshold map on a canvas, as described in the DB paper
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """Draw a polygon threshold map on a canvas, as described in the DB paper
 
         Args:
-        ----
             polygon : array of coord., to draw the boundary of the polygon
             canvas : threshold map to fill with polygons
             mask : mask for training on threshold polygons
@@ -278,10 +267,10 @@ class _DBNet:
 
     def build_target(
         self,
-        target: List[Dict[str, np.ndarray]],
-        output_shape: Tuple[int, int, int],
+        target: list[dict[str, np.ndarray]],
+        output_shape: tuple[int, int, int],
         channels_last: bool = True,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
             raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
         if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
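The notable change in this file is that the Shapely `Polygon` dependency drops out of the unclip step: area and perimeter now come from OpenCV, while the expansion distance keeps the DB paper's formula D = A * r / L ahead of the pyclipper offset. A self-contained sketch of that step under those assumptions (the ratio value here is illustrative; `DBPostProcessor` carries its own `unclip_ratio`):

import cv2
import numpy as np
import pyclipper

def unclip(points: np.ndarray, unclip_ratio: float = 1.5) -> np.ndarray:
    # DB expansion distance: D = area * ratio / perimeter
    area = cv2.contourArea(points)
    length = cv2.arcLength(points, closed=True)
    distance = area * unclip_ratio / length
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    return np.array(offset.Execute(distance)[0])

shrunk = np.array([[10, 10], [40, 10], [40, 30], [10, 30]], dtype=np.int32)
print(unclip(shrunk))  # expanded quadrilateral with rounded corners

Swapping Shapely for OpenCV drops a heavyweight dependency without changing the computed geometry for the simple polygons this code handles.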