python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/classification/vit/pytorch.py CHANGED
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import torch
 from torch import nn
@@ -18,7 +18,7 @@ from ...utils.pytorch import load_pretrained_params
 __all__ = ["vit_s", "vit_b"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "vit_s": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -40,7 +40,6 @@ class ClassifierHead(nn.Module):
     """Classifier head for Vision Transformer
 
     Args:
-    ----
         in_channels: number of input channels
        num_classes: number of output classes
    """
@@ -65,7 +64,6 @@ class VisionTransformer(nn.Sequential):
     <https://arxiv.org/pdf/2010.11929.pdf>`_.
 
     Args:
-    ----
         d_model: dimension of the transformer layers
         num_layers: number of transformer layers
         num_heads: number of attention heads
@@ -83,14 +81,14 @@ class VisionTransformer(nn.Sequential):
         num_layers: int,
         num_heads: int,
         ffd_ratio: int,
-        patch_size: Tuple[int, int] = (4, 4),
-        input_shape: Tuple[int, int, int] = (3, 32, 32),
+        patch_size: tuple[int, int] = (4, 4),
+        input_shape: tuple[int, int, int] = (3, 32, 32),
         dropout: float = 0.0,
         num_classes: int = 1000,
         include_top: bool = True,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
-        _layers: List[nn.Module] = [
+        _layers: list[nn.Module] = [
             PatchEmbedding(input_shape, d_model, patch_size),
             EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU()),
         ]
@@ -104,7 +102,7 @@ class VisionTransformer(nn.Sequential):
 def _vit(
     arch: str,
     pretrained: bool,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> VisionTransformer:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -143,12 +141,10 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
@@ -175,12 +171,10 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
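
Note: the typing change above recurs across most of the 162 files in this release. `Dict`, `List`, `Optional`, and `Tuple` from `typing` are replaced by PEP 585 built-in generics and PEP 604 unions, with no `from __future__ import annotations` added, which suggests a Python 3.10+ floor. A minimal sketch of the new style; `make_cfg` is a hypothetical helper, shown only to illustrate the annotations:

from typing import Any

# Hypothetical helper illustrating the annotation style adopted here:
# built-in generics (dict, tuple) and "X | None" in place of Optional[X].
def make_cfg(overrides: dict[str, Any] | None = None) -> dict[str, Any]:
    cfg: dict[str, Any] = {"mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301)}
    if overrides is not None:
        cfg.update(overrides)
    return cfg
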
doctr/models/classification/vit/tensorflow.py CHANGED
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Sequential, layers
@@ -14,25 +14,25 @@ from doctr.models.modules.transformer import EncoderBlock
 from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
 from doctr.utils.repr import NestedObject
 
-from ...utils import load_pretrained_params
+from ...utils import _build_model, load_pretrained_params
 
 __all__ = ["vit_s", "vit_b"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "vit_s": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-6300fcc9.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
     },
     "vit_b": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-57158446.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
     },
 }
 
@@ -41,7 +41,6 @@ class ClassifierHead(layers.Layer, NestedObject):
     """Classifier head for Vision Transformer
 
     Args:
-    ----
         num_classes: number of output classes
     """
 
@@ -61,7 +60,6 @@ class VisionTransformer(Sequential):
     <https://arxiv.org/pdf/2010.11929.pdf>`_.
 
     Args:
-    ----
         d_model: dimension of the transformer layers
         num_layers: number of transformer layers
         num_heads: number of attention heads
@@ -79,12 +77,12 @@ class VisionTransformer(Sequential):
         num_layers: int,
         num_heads: int,
         ffd_ratio: int,
-        patch_size: Tuple[int, int] = (4, 4),
-        input_shape: Tuple[int, int, int] = (32, 32, 3),
+        patch_size: tuple[int, int] = (4, 4),
+        input_shape: tuple[int, int, int] = (32, 32, 3),
         dropout: float = 0.0,
         num_classes: int = 1000,
         include_top: bool = True,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         _layers = [
             PatchEmbedding(input_shape, d_model, patch_size),
@@ -121,9 +119,15 @@ def _vit(
 
     # Build the model
     model = VisionTransformer(cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
@@ -142,12 +146,10 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
@@ -173,12 +175,10 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
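
Note: beyond the typing cleanup, the TensorFlow `_vit` factory gains two behavioral changes: the model is now explicitly built before weight loading (`_build_model`), and pretrained checkpoints move from `.zip` archives to `.weights.h5` files, with `skip_mismatch` allowing a custom head on top of pretrained weights. A minimal sketch of what the new path enables, assuming the TensorFlow backend is installed; exact fine-tuning behavior depends on factory details not shown in this hunk:

from doctr.models import classification

# Head matches the pretrained vocab-sized classifier: every layer loads.
model = classification.vit_s(pretrained=True)

# Custom class count: skip_mismatch evaluates to True, so all layers load
# except the mismatching classification head, which stays randomly
# initialized for fine-tuning.
finetune = classification.vit_s(pretrained=True, num_classes=10)
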
doctr/models/classification/zoo.py CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List
+from typing import Any
 
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 
 from .. import classification
 from ..preprocessor import PreProcessor
@@ -13,7 +13,7 @@ from .predictor import OrientationPredictor
 
 __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
 
-ARCHS: List[str] = [
+ARCHS: list[str] = [
     "magc_resnet31",
     "mobilenet_v3_small",
     "mobilenet_v3_small_r",
@@ -31,18 +31,37 @@ ARCHS: List[str] = [
     "vit_s",
     "vit_b",
 ]
-ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
+ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
-def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor:
-    if arch not in ORIENTATION_ARCHS:
-        raise ValueError(f"unknown architecture '{arch}'")
+def _orientation_predictor(
+    arch: Any, pretrained: bool, model_type: str, disabled: bool = False, **kwargs: Any
+) -> OrientationPredictor:
+    if disabled:
+        # Case where the orientation predictor is disabled
+        return OrientationPredictor(None, None)
+
+    if isinstance(arch, str):
+        if arch not in ORIENTATION_ARCHS:
+            raise ValueError(f"unknown architecture '{arch}'")
+
+        # Load directly classifier from backbone
+        _model = classification.__dict__[arch](pretrained=pretrained)
+    else:
+        allowed_archs = [classification.MobileNetV3]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
+            raise ValueError(f"unknown architecture: {type(arch)}")
+        _model = arch
 
-    # Load directly classifier from backbone
-    _model = classification.__dict__[arch](pretrained=pretrained)
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
+    kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
     input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
     predictor = OrientationPredictor(
         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
@@ -51,7 +70,7 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient
 
 
 def crop_orientation_predictor(
-    arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
+    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, batch_size: int = 128, **kwargs: Any
 ) -> OrientationPredictor:
     """Crop orientation classification architecture.
 
@@ -62,20 +81,19 @@ def crop_orientation_predictor(
     >>> out = model([input_crop])
 
     Args:
-    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        batch_size: number of samples the model processes in parallel
         **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
-    -------
         OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, **kwargs)
+    return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="crop", **kwargs)
 
 
 def page_orientation_predictor(
-    arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
+    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, batch_size: int = 4, **kwargs: Any
 ) -> OrientationPredictor:
     """Page orientation classification architecture.
 
@@ -86,13 +104,12 @@ def page_orientation_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        batch_size: number of samples the model processes in parallel
         **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
-    -------
         OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, **kwargs)
+    return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="page", **kwargs)
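
Note: the orientation factories now accept either an architecture name or a prebuilt model as `arch` (a `classification.MobileNetV3` instance, or a torch-compiled wrapper), expose `batch_size` directly, and can be turned into no-ops via the `disabled` flag of `_orientation_predictor`, which returns `OrientationPredictor(None, None)`. A usage sketch; the zero-filled crop is illustrative only:

import numpy as np
from doctr.models import crop_orientation_predictor

# Default mobilenet_v3_small_crop_orientation backbone, custom batch size.
predictor = crop_orientation_predictor(pretrained=True, batch_size=64)

# Dummy (H, W, C) uint8 crop, for illustration.
crop = np.zeros((64, 256, 3), dtype=np.uint8)
out = predictor([crop])
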
doctr/models/core.py CHANGED
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
-from typing import Any, Dict, Optional
+from typing import Any
 
 from doctr.utils.repr import NestedObject
 
@@ -14,6 +14,6 @@ __all__ = ["BaseModel"]
 class BaseModel(NestedObject):
     """Implements abstract DetectionModel class"""
 
-    def __init__(self, cfg: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, cfg: dict[str, Any] | None = None) -> None:
         super().__init__()
         self.cfg = cfg
doctr/models/detection/_utils/__init__.py CHANGED
@@ -1,7 +1,7 @@
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from .base import *
 
-if is_tf_available():
-    from .tensorflow import *
-else:
+if is_torch_available():
     from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *
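
Note the priority flip above, repeated in the other backend-dispatch modules in this release: when both frameworks are installed, 0.11.0 imports the PyTorch implementation first, whereas 0.9.0 preferred TensorFlow. A quick way to see which backend an environment resolves to, mirroring the new import order:

from doctr.file_utils import is_tf_available, is_torch_available

# Mirrors the 0.11.0 import order: torch wins when both backends are present.
backend = "pytorch" if is_torch_available() else "tensorflow" if is_tf_available() else "none"
print(backend)
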
doctr/models/detection/_utils/base.py CHANGED
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Dict, List
 
 import numpy as np
 
@@ -11,16 +10,15 @@ __all__ = ["_remove_padding"]
 
 
 def _remove_padding(
-    pages: List[np.ndarray],
-    loc_preds: List[Dict[str, np.ndarray]],
+    pages: list[np.ndarray],
+    loc_preds: list[dict[str, np.ndarray]],
     preserve_aspect_ratio: bool,
     symmetric_pad: bool,
     assume_straight_pages: bool,
-) -> List[Dict[str, np.ndarray]]:
+) -> list[dict[str, np.ndarray]]:
     """Remove padding from the localization predictions
 
     Args:
-    ----
         pages: list of pages
         loc_preds: list of localization predictions
         preserve_aspect_ratio: whether the aspect ratio was preserved during padding
@@ -28,7 +26,6 @@ def _remove_padding(
         assume_straight_pages: whether the pages are assumed to be straight
 
     Returns:
-    -------
         list of unpaded localization predictions
     """
     if preserve_aspect_ratio:
doctr/models/detection/_utils/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,12 +13,10 @@ def erode(x: Tensor, kernel_size: int) -> Tensor:
     """Performs erosion on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, C, H, W)
         kernel_size: the size of the kernel to use for erosion
 
     Returns:
-    -------
         the eroded tensor
     """
     _pad = (kernel_size - 1) // 2
@@ -30,12 +28,10 @@ def dilate(x: Tensor, kernel_size: int) -> Tensor:
     """Performs dilation on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, C, H, W)
         kernel_size: the size of the kernel to use for dilation
 
     Returns:
-    -------
         the dilated tensor
     """
     _pad = (kernel_size - 1) // 2
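
Note: apart from the docstring cleanup, `erode` and `dilate` are unchanged. Both are morphological operators built on max-pooling, with erosion implemented as the complement of a dilation of the complement. A quick sanity check of the PyTorch pair; the import path is assumed from the backend-dispatching `__init__` above:

import torch
from doctr.models.detection._utils import erode, dilate

# A 5x5 map holding a single 3x3 blob, shaped (N, C, H, W).
x = torch.zeros(1, 1, 5, 5)
x[0, 0, 1:4, 1:4] = 1.0

assert erode(x, 3).sum() == 1    # only the blob's center survives a 3x3 erosion
assert dilate(x, 3).sum() == 25  # the blob grows to cover the whole 5x5 map
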
doctr/models/detection/_utils/tensorflow.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,12 +12,10 @@ def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
     """Performs erosion on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, H, W, C)
         kernel_size: the size of the kernel to use for erosion
 
     Returns:
-    -------
         the eroded tensor
     """
     return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
@@ -27,12 +25,10 @@ def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
     """Performs dilation on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, H, W, C)
         kernel_size: the size of the kernel to use for dilation
 
     Returns:
-    -------
         the dilated tensor
     """
     return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
doctr/models/detection/core.py CHANGED
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List
 
 import cv2
 import numpy as np
@@ -17,7 +16,6 @@ class DetectionPostProcessor(NestedObject):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         box_thresh (float): minimal objectness score to consider a box
         bin_thresh (float): threshold to apply to segmentation raw heatmap
         assume straight_pages (bool): if True, fit straight boxes only
@@ -37,13 +35,11 @@ class DetectionPostProcessor(NestedObject):
         """Compute the confidence score for a polygon : mean of the p values on the polygon
 
         Args:
-        ----
             pred (np.ndarray): p map returned by the model
             points: coordinates of the polygon
             assume_straight_pages: if True, fit straight boxes only
 
         Returns:
-        -------
             polygon objectness
         """
         h, w = pred.shape[:2]
@@ -71,15 +67,13 @@ class DetectionPostProcessor(NestedObject):
     def __call__(
         self,
         proba_map,
-    ) -> List[List[np.ndarray]]:
+    ) -> list[list[np.ndarray]]:
         """Performs postprocessing for a list of model outputs
 
         Args:
-        ----
             proba_map: probability map of shape (N, H, W, C)
 
         Returns:
-        -------
             list of N class predictions (for each input sample), where each class predictions is a list of C tensors
             of shape (*, 5) or (*, 6)
         """
doctr/models/detection/differentiable_binarization/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/detection/differentiable_binarization/base.py CHANGED
@@ -1,11 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -22,7 +21,6 @@ class DBPostProcessor(DetectionPostProcessor):
     <https://github.com/xuannianz/DifferentiableBinarization>`_.
 
     Args:
-    ----
         unclip ratio: ratio used to unshrink polygons
         min_size_box: minimal length (pix) to keep a box
         max_candidates: maximum boxes to consider in a single page
@@ -47,11 +45,9 @@ class DBPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
-        ----
             points: The first parameter.
 
         Returns:
-        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -96,25 +92,23 @@ class DBPostProcessor(DetectionPostProcessor):
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
-        ----
             pred: Pred map from differentiable binarization output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
-        -------
             np tensor boxes for the bitmap, each box is a 5-element list
             containing x, y, w, h, score for the box
         """
         height, width = bitmap.shape[:2]
         min_size_box = 2
-        boxes: List[Union[np.ndarray, List[float]]] = []
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):  # type: ignore[index]
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -164,7 +158,6 @@ class _DBNet:
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
     """
@@ -186,7 +179,6 @@ class _DBNet:
         """Compute the distance for each point of the map (xs, ys) to the (a, b) segment
 
         Args:
-        ----
             xs : map of x coordinates (height, width)
             ys : map of y coordinates (height, width)
             a: first point defining the [ab] segment
@@ -194,7 +186,6 @@ class _DBNet:
             eps: epsilon to avoid division by zero
 
         Returns:
-        -------
             The computed distance
 
         """
@@ -214,11 +205,10 @@ class _DBNet:
         polygon: np.ndarray,
         canvas: np.ndarray,
         mask: np.ndarray,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Draw a polygon treshold map on a canvas, as described in the DB paper
 
         Args:
-        ----
             polygon : array of coord., to draw the boundary of the polygon
             canvas : threshold map to fill with polygons
             mask : mask for training on threshold polygons
@@ -278,10 +268,10 @@ class _DBNet:
 
     def build_target(
         self,
-        target: List[Dict[str, np.ndarray]],
-        output_shape: Tuple[int, int, int],
+        target: list[dict[str, np.ndarray]],
+        output_shape: tuple[int, int, int],
         channels_last: bool = True,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
             raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
         if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):