python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two released versions.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +8 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +7 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +4 -5
  17. doctr/datasets/ic13.py +4 -5
  18. doctr/datasets/iiit5k.py +6 -5
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +6 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +6 -5
  27. doctr/datasets/svhn.py +6 -5
  28. doctr/datasets/svt.py +4 -5
  29. doctr/datasets/synthtext.py +4 -5
  30. doctr/datasets/utils.py +34 -29
  31. doctr/datasets/vocabs.py +17 -7
  32. doctr/datasets/wildreceipt.py +14 -10
  33. doctr/file_utils.py +2 -7
  34. doctr/io/elements.py +59 -79
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +30 -48
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +8 -11
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +5 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +8 -21
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +6 -8
  52. doctr/models/classification/predictor/tensorflow.py +6 -8
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +20 -31
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +8 -15
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +9 -12
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +6 -12
  65. doctr/models/classification/zoo.py +19 -14
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +14 -26
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +14 -23
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +5 -6
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +3 -7
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +4 -5
  91. doctr/models/kie_predictor/pytorch.py +18 -19
  92. doctr/models/kie_predictor/tensorflow.py +13 -14
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -10
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  101. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +28 -29
  104. doctr/models/predictor/pytorch.py +12 -13
  105. doctr/models/predictor/tensorflow.py +8 -9
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +10 -14
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +11 -23
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +12 -22
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +16 -22
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +12 -21
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +12 -20
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +14 -17
  136. doctr/models/utils/tensorflow.py +17 -16
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +20 -28
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +58 -22
  145. doctr/transforms/modules/tensorflow.py +18 -32
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +16 -47
  150. doctr/utils/metrics.py +17 -37
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +9 -13
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.10.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/classification/zoo.py CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List
+from typing import Any
 
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 
 from .. import classification
 from ..preprocessor import PreProcessor
@@ -13,7 +13,7 @@ from .predictor import OrientationPredictor
 
 __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
 
-ARCHS: List[str] = [
+ARCHS: list[str] = [
     "magc_resnet31",
     "mobilenet_v3_small",
     "mobilenet_v3_small_r",
@@ -31,7 +31,7 @@ ARCHS: List[str] = [
     "vit_s",
     "vit_b",
 ]
-ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
+ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
 def _orientation_predictor(
@@ -48,7 +48,14 @@ def _orientation_predictor(
         # Load directly classifier from backbone
         _model = classification.__dict__[arch](pretrained=pretrained)
     else:
-        if not isinstance(arch, classification.MobileNetV3):
+        allowed_archs = [classification.MobileNetV3]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch
 
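The practical effect of the new `_CompiledModule` branch is that a `torch.compile`-wrapped classifier is now accepted wherever a plain model instance was. A minimal sketch, assuming a PyTorch install of doctr (variable names are illustrative):

```python
import torch

from doctr.models import classification
from doctr.models.classification.zoo import crop_orientation_predictor

# 0.10.0 rejected compiled models via the isinstance check above;
# 0.11.0 whitelists the compiled-module type as well.
model = classification.mobilenet_v3_small_crop_orientation(pretrained=True)
compiled_model = torch.compile(model)
predictor = crop_orientation_predictor(arch=compiled_model)
```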
@@ -63,7 +70,7 @@
 
 
 def crop_orientation_predictor(
-    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
+    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, batch_size: int = 128, **kwargs: Any
 ) -> OrientationPredictor:
     """Crop orientation classification architecture.
 
@@ -74,20 +81,19 @@
     >>> out = model([input_crop])
 
     Args:
-    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        batch_size: number of samples the model processes in parallel
         **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
-    -------
         OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)
+    return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="crop", **kwargs)
 
 
 def page_orientation_predictor(
-    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
+    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, batch_size: int = 4, **kwargs: Any
 ) -> OrientationPredictor:
     """Page orientation classification architecture.
 
@@ -98,13 +104,12 @@
     >>> out = model([input_page])
 
     Args:
-    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        batch_size: number of samples the model processes in parallel
         **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
-    -------
         OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
+    return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="page", **kwargs)
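Both factories now expose `batch_size` directly (defaulting to 128 for crops and 4 for pages) and forward every argument to `_orientation_predictor` by keyword. A hedged usage sketch, following the `>>> out = model([input_crop])` convention from the doctests above:

```python
import numpy as np

from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor

# New in 0.11.0: tune throughput without touching PreProcessor kwargs.
crop_clf = crop_orientation_predictor(pretrained=True, batch_size=64)
page_clf = page_orientation_predictor(pretrained=True, batch_size=2)

crop = np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8)
out = crop_clf([crop])  # orientation predictions for the batch of crops
```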
doctr/models/core.py CHANGED
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
-from typing import Any, Dict, Optional
+from typing import Any
 
 from doctr.utils.repr import NestedObject
 
@@ -14,6 +14,6 @@ __all__ = ["BaseModel"]
 class BaseModel(NestedObject):
     """Implements abstract DetectionModel class"""
 
-    def __init__(self, cfg: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, cfg: dict[str, Any] | None = None) -> None:
         super().__init__()
         self.cfg = cfg
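The `Optional[Dict[...]]` → `dict[...] | None` rewrite above is the typing modernization (PEP 585 builtin generics plus PEP 604 unions) that recurs in nearly every file of this diff. The two spellings are equivalent; a minimal illustration with a made-up function name:

```python
from typing import Any

# 0.10.0 style:  cfg: Optional[Dict[str, Any]] = None
# 0.11.0 style, valid at runtime on Python 3.10+
# (or earlier with `from __future__ import annotations`):
def configure(cfg: dict[str, Any] | None = None) -> dict[str, Any]:
    return cfg or {}
```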
doctr/models/detection/_utils/__init__.py CHANGED
@@ -1,7 +1,7 @@
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from .base import *
 
-if is_tf_available():
-    from .tensorflow import *
-else:
+if is_torch_available():
     from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *
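Note the inverted branch order: with both frameworks installed, 0.11.0 resolves to the PyTorch implementations first, where 0.10.0 preferred TensorFlow. The same resolution logic in isolation, using the `doctr.file_utils` helpers imported above:

```python
from doctr.file_utils import is_tf_available, is_torch_available

# Mirrors the new import priority: PyTorch wins when both are present.
if is_torch_available():
    backend = "pytorch"
elif is_tf_available():
    backend = "tensorflow"
else:
    raise RuntimeError("doctr requires PyTorch or TensorFlow")
print(f"doctr resolves to the {backend} implementations")
```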
doctr/models/detection/_utils/base.py CHANGED
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Dict, List
 
 import numpy as np
 
@@ -11,16 +10,15 @@ __all__ = ["_remove_padding"]
 
 
 def _remove_padding(
-    pages: List[np.ndarray],
-    loc_preds: List[Dict[str, np.ndarray]],
+    pages: list[np.ndarray],
+    loc_preds: list[dict[str, np.ndarray]],
     preserve_aspect_ratio: bool,
     symmetric_pad: bool,
     assume_straight_pages: bool,
-) -> List[Dict[str, np.ndarray]]:
+) -> list[dict[str, np.ndarray]]:
     """Remove padding from the localization predictions
 
     Args:
-    ----
         pages: list of pages
         loc_preds: list of localization predictions
         preserve_aspect_ratio: whether the aspect ratio was preserved during padding
@@ -28,7 +26,6 @@ def _remove_padding(
         assume_straight_pages: whether the pages are assumed to be straight
 
     Returns:
-    -------
         list of unpaded localization predictions
     """
     if preserve_aspect_ratio:
doctr/models/detection/_utils/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,12 +13,10 @@ def erode(x: Tensor, kernel_size: int) -> Tensor:
     """Performs erosion on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, C, H, W)
         kernel_size: the size of the kernel to use for erosion
 
     Returns:
-    -------
         the eroded tensor
     """
     _pad = (kernel_size - 1) // 2
@@ -30,12 +28,10 @@ def dilate(x: Tensor, kernel_size: int) -> Tensor:
     """Performs dilation on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, C, H, W)
         kernel_size: the size of the kernel to use for dilation
 
     Returns:
-    -------
         the dilated tensor
     """
     _pad = (kernel_size - 1) // 2
doctr/models/detection/_utils/tensorflow.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,12 +12,10 @@ def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
     """Performs erosion on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, H, W, C)
         kernel_size: the size of the kernel to use for erosion
 
     Returns:
-    -------
         the eroded tensor
     """
     return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
@@ -27,12 +25,10 @@ def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
     """Performs dilation on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, H, W, C)
         kernel_size: the size of the kernel to use for dilation
 
     Returns:
-    -------
         the dilated tensor
     """
     return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
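Both backend files implement morphology with max pooling, as the TensorFlow bodies above show: dilation is a sliding-window max, and erosion is the complement of dilating the complement. A hedged PyTorch sketch of the same trick (the real functions in `doctr/models/detection/_utils/pytorch.py` use the `_pad` value computed above):

```python
import torch
import torch.nn.functional as F

def dilate(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    # A sliding-window max over (N, C, H, W) acts as binary dilation.
    _pad = (kernel_size - 1) // 2
    return F.max_pool2d(x, kernel_size, stride=1, padding=_pad)

def erode(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    # Erosion is the complement of the dilation of the complement.
    return 1 - dilate(1 - x, kernel_size)

mask = (torch.rand(1, 1, 32, 32) > 0.5).float()
opened = dilate(erode(mask, 3), 3)  # morphological opening
```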
doctr/models/detection/core.py CHANGED
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List
 
 import cv2
 import numpy as np
@@ -17,7 +16,6 @@ class DetectionPostProcessor(NestedObject):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         box_thresh (float): minimal objectness score to consider a box
         bin_thresh (float): threshold to apply to segmentation raw heatmap
         assume straight_pages (bool): if True, fit straight boxes only
@@ -37,13 +35,11 @@
         """Compute the confidence score for a polygon : mean of the p values on the polygon
 
         Args:
-        ----
             pred (np.ndarray): p map returned by the model
             points: coordinates of the polygon
             assume_straight_pages: if True, fit straight boxes only
 
         Returns:
-        -------
             polygon objectness
         """
         h, w = pred.shape[:2]
@@ -71,15 +67,13 @@
     def __call__(
         self,
         proba_map,
-    ) -> List[List[np.ndarray]]:
+    ) -> list[list[np.ndarray]]:
         """Performs postprocessing for a list of model outputs
 
         Args:
-        ----
             proba_map: probability map of shape (N, H, W, C)
 
         Returns:
-        -------
             list of N class predictions (for each input sample), where each class predictions is a list of C tensors
             of shape (*, 5) or (*, 6)
         """
doctr/models/detection/differentiable_binarization/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/detection/differentiable_binarization/base.py CHANGED
@@ -1,11 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -22,7 +21,6 @@ class DBPostProcessor(DetectionPostProcessor):
     <https://github.com/xuannianz/DifferentiableBinarization>`_.
 
     Args:
-    ----
         unclip ratio: ratio used to unshrink polygons
         min_size_box: minimal length (pix) to keep a box
         max_candidates: maximum boxes to consider in a single page
@@ -47,11 +45,9 @@
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
-        ----
             points: The first parameter.
 
         Returns:
-        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -96,25 +92,23 @@
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
-        ----
             pred: Pred map from differentiable binarization output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
-        -------
             np tensor boxes for the bitmap, each box is a 5-element list
             containing x, y, w, h, score for the box
         """
         height, width = bitmap.shape[:2]
         min_size_box = 2
-        boxes: List[Union[np.ndarray, List[float]]] = []
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):  # type: ignore[index]
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -164,7 +158,6 @@ class _DBNet:
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
     """
@@ -186,7 +179,6 @@
         """Compute the distance for each point of the map (xs, ys) to the (a, b) segment
 
         Args:
-        ----
             xs : map of x coordinates (height, width)
             ys : map of y coordinates (height, width)
             a: first point defining the [ab] segment
@@ -194,7 +186,6 @@
             eps: epsilon to avoid division by zero
 
         Returns:
-        -------
             The computed distance
 
         """
@@ -214,11 +205,10 @@
         polygon: np.ndarray,
         canvas: np.ndarray,
         mask: np.ndarray,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Draw a polygon treshold map on a canvas, as described in the DB paper
 
         Args:
-        ----
             polygon : array of coord., to draw the boundary of the polygon
             canvas : threshold map to fill with polygons
             mask : mask for training on threshold polygons
@@ -278,10 +268,10 @@
 
     def build_target(
         self,
-        target: List[Dict[str, np.ndarray]],
-        output_shape: Tuple[int, int, int],
+        target: list[dict[str, np.ndarray]],
+        output_shape: tuple[int, int, int],
         channels_last: bool = True,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
             raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
         if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
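The `bitmap_to_boxes` body above centers on OpenCV connected components. A stripped-down sketch of the straight-page branch, keeping only the contour filtering visible in the hunk (scoring and unclipping omitted):

```python
import cv2
import numpy as np

def bitmap_to_straight_boxes(bitmap: np.ndarray, min_size_box: int = 2) -> np.ndarray:
    # External contours on the binarized map; drop components whose
    # enclosing box is smaller than min_size_box, keep (x, y, w, h).
    contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = []
    for contour in contours:
        if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
            continue
        boxes.append(cv2.boundingRect(contour))  # (x, y, w, h)
    return np.asarray(boxes, dtype=int)
```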
doctr/models/detection/differentiable_binarization/pytorch.py CHANGED
@@ -1,9 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, Callable, Dict, List, Optional
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 import torch
@@ -22,7 +23,7 @@ from .base import DBPostProcessor, _DBNet
 __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "db_resnet50": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
@@ -47,7 +48,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 class FeaturePyramidNetwork(nn.Module):
     def __init__(
         self,
-        in_channels: List[int],
+        in_channels: list[int],
         out_channels: int,
         deform_conv: bool = False,
     ) -> None:
@@ -76,12 +77,12 @@
             for idx, chans in enumerate(in_channels)
         ])
 
-    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+    def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
         if len(x) != len(self.out_branches):
             raise AssertionError
         # Conv1x1 to get the same number of channels
-        _x: List[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
-        out: List[torch.Tensor] = [_x[-1]]
+        _x: list[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
+        out: list[torch.Tensor] = [_x[-1]]
         for t in _x[:-1][::-1]:
             out.append(self.upsample(out[-1]) + t)
 
@@ -96,7 +97,6 @@ class DBNet(_DBNet, nn.Module):
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         head_chans: the number of channels in the head
         deform_conv: whether to use deformable convolution
@@ -117,8 +117,8 @@
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
-        class_names: List[str] = [CLASS_NAME],
+        cfg: dict[str, Any] | None = None,
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -182,10 +182,10 @@
     def forward(
         self,
         x: torch.Tensor,
-        target: Optional[List[np.ndarray]] = None,
+        target: list[np.ndarray] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, torch.Tensor]:
+    ) -> dict[str, torch.Tensor]:
         # Extract feature maps at different stages
         feats = self.feat_extractor(x)
         feats = [feats[str(idx)] for idx in range(len(feats))]
@@ -193,7 +193,7 @@
         feat_concat = self.fpn(feats)
         logits = self.prob_head(feat_concat)
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -205,11 +205,16 @@
         out["out_map"] = prob_map
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
+                return [
+                    dict(zip(self.class_names, preds))
+                    for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                ]
+
             # Post-process boxes (keep only text predictions)
-            out["preds"] = [
-                dict(zip(self.class_names, preds))
-                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
-            ]
+            out["preds"] = _postprocess(prob_map)
 
         if target is not None:
             thresh_map = self.thresh_head(feat_concat)
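The locally scoped `_postprocess` exists so that `torch.compile` never traces the numpy-based post-processing: `torch.compiler.disable` (available in torch >= 2.1) forces the decorated function to run eagerly even when called from compiled code. The pattern in isolation, with an illustrative function name:

```python
import torch

@torch.compiler.disable
def to_numpy_preds(prob_map: torch.Tensor) -> list:
    # Runs eagerly: leaves the compiled graph before touching numpy.
    return prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy().tolist()
```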
@@ -222,7 +227,7 @@
         self,
         out_map: torch.Tensor,
         thresh_map: torch.Tensor,
-        target: List[np.ndarray],
+        target: list[np.ndarray],
         gamma: float = 2.0,
         alpha: float = 0.5,
         eps: float = 1e-8,
@@ -231,7 +236,6 @@
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
-        ----
             out_map: output feature map of the model of shape (N, C, H, W)
             thresh_map: threshold map of shape (N, C, H, W)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
@@ -240,7 +244,6 @@
             eps: epsilon factor in dice loss
 
         Returns:
-        -------
             A loss tensor
         """
         if gamma < 0:
@@ -273,7 +276,7 @@
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
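The `1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))` term is the paper's differentiable binarization: a steep sigmoid with slope factor k = 50 that approximates hard-thresholding P at the learned threshold map T while keeping gradients. Equivalently, as a sketch:

```python
import torch

def approx_binary_map(prob_map: torch.Tensor, thresh_map: torch.Tensor, k: float = 50.0) -> torch.Tensor:
    # sigmoid(k * (P - T)) == 1 / (1 + exp(-k * (P - T)))
    return torch.sigmoid(k * (prob_map - thresh_map))
```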
@@ -290,10 +293,10 @@ def _dbnet(
     arch: str,
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
-    fpn_layers: List[str],
-    backbone_submodule: Optional[str] = None,
+    fpn_layers: list[str],
+    backbone_submodule: str | None = None,
     pretrained_backbone: bool = True,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -341,12 +344,10 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
@@ -376,12 +377,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
@@ -411,12 +410,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(