python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/classification/mobilenet/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,15 +6,17 @@
 # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 from torchvision.models import mobilenetv3
+from torchvision.models.mobilenetv3 import MobileNetV3
 
 from doctr.datasets import VOCABS
 
 from ...utils import load_pretrained_params
 
 __all__ = [
+    "MobileNetV3",
     "mobilenet_v3_small",
     "mobilenet_v3_small_r",
     "mobilenet_v3_large",
@@ -23,7 +25,7 @@ __all__ = [
     "mobilenet_v3_small_page_orientation",
 ]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "mobilenet_v3_large": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -72,8 +74,8 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 def _mobilenet_v3(
     arch: str,
     pretrained: bool,
-    rect_strides: Optional[List[str]] = None,
-    ignore_keys: Optional[List[str]] = None,
+    rect_strides: list[str] | None = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> mobilenetv3.MobileNetV3:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -121,12 +123,10 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -146,12 +146,10 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -175,12 +173,10 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -203,12 +199,10 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -232,12 +226,10 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -260,12 +252,10 @@ def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
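
Most of the churn in this file (and in the rest of the release) is a mechanical typing migration: the deprecated typing.Dict/List/Optional/Tuple/Union aliases are swapped for PEP 585 built-in generics and PEP 604 union syntax. A minimal before/after sketch of the pattern (illustrative function, not doctr code):

# Pre-0.11.0 annotation style: typing aliases
from typing import Any, Dict, List, Optional

def configure_old(arch: str, ignore_keys: Optional[List[str]] = None) -> Dict[str, Any]:
    # Optional[X] is shorthand for Union[X, None]
    return {"arch": arch, "ignore_keys": ignore_keys or []}

# 0.11.0 annotation style: built-in generics (PEP 585) and the | union
# operator (PEP 604); as bare annotations these need Python >= 3.10
# (or a `from __future__ import annotations` import on older interpreters)
def configure_new(arch: str, ignore_keys: list[str] | None = None) -> dict[str, Any]:
    return {"arch": arch, "ignore_keys": ignore_keys or []}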
doctr/models/classification/mobilenet/tensorflow.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,14 +6,14 @@
 # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import layers
 from tensorflow.keras.models import Sequential
 
 from ....datasets import VOCABS
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = [
     "MobileNetV3",
@@ -26,48 +26,48 @@ __all__ = [
 ]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "mobilenet_v3_large": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-47d25d7e.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
     },
     "mobilenet_v3_large_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-a108e192.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
     },
     "mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-8a32c32c.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
     },
     "mobilenet_v3_small_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
     },
     "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (128, 128, 3),
         "classes": [0, -90, 180, 90],
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
     },
     "mobilenet_v3_small_page_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (512, 512, 3),
         "classes": [0, -90, 180, 90],
-        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-aec9553e.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
     },
 }
 
@@ -76,7 +76,7 @@ def hard_swish(x: tf.Tensor) -> tf.Tensor:
     return x * tf.nn.relu6(x + 3.0) / 6.0
 
 
-def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+def _make_divisible(v: float, divisor: int, min_value: int | None = None) -> int:
     if min_value is None:
         min_value = divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -112,7 +112,7 @@ class InvertedResidualConfig:
         out_channels: int,
         use_se: bool,
         activation: str,
-        stride: Union[int, Tuple[int, int]],
+        stride: int | tuple[int, int],
         width_mult: float = 1,
     ) -> None:
         self.input_channels = self.adjust_channels(input_channels, width_mult)
@@ -132,7 +132,6 @@ class InvertedResidual(layers.Layer):
     """InvertedResidual for mobilenet
 
     Args:
-    ----
         conf: configuration object for inverted residual
     """
 
@@ -201,12 +200,12 @@ class MobileNetV3(Sequential):
 
     def __init__(
         self,
-        layout: List[InvertedResidualConfig],
+        layout: list[InvertedResidualConfig],
         include_top: bool = True,
         head_chans: int = 1024,
         num_classes: int = 1000,
-        cfg: Optional[Dict[str, Any]] = None,
-        input_shape: Optional[Tuple[int, int, int]] = None,
+        cfg: dict[str, Any] | None = None,
+        input_shape: tuple[int, int, int] | None = None,
     ) -> None:
         _layers = [
             Sequential(
@@ -295,9 +294,15 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
@@ -314,12 +319,10 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small", pretrained, False, **kwargs)
@@ -337,12 +340,10 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small_r", pretrained, True, **kwargs)
@@ -360,12 +361,10 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_large", pretrained, False, **kwargs)
@@ -383,12 +382,10 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)
@@ -406,12 +403,10 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)
@@ -429,12 +424,10 @@ def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small_page_orientation", pretrained, include_top=True, **kwargs)
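
Besides the typing cleanup and the move from .zip checkpoints to .weights.h5 files, the TensorFlow builder now calls _build_model() before restoring weights and forwards skip_mismatch to load_pretrained_params whenever the requested num_classes differs from the pretrained configuration, so the classification head can be re-initialized for fine-tuning. A hedged usage sketch (TensorFlow backend; the factory import path is assumed from doctr's public classification package):

from doctr.models.classification import mobilenet_v3_small

# With the default vocab, every layer of the checkpoint is restored.
model = mobilenet_v3_small(pretrained=True)

# With a custom class count, skip_mismatch=True is passed internally, so the
# pretrained backbone loads while the mismatching head layers are skipped.
finetune_model = mobilenet_v3_small(pretrained=True, num_classes=10)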
doctr/models/classification/predictor/__init__.py

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
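
Note the flipped branch order: up to 0.9.0 the TensorFlow implementation was imported first, so it won whenever both frameworks were installed; from this release PyTorch takes precedence. A quick check of which backend an environment resolves to, mirroring the dispatch above:

from doctr.file_utils import is_tf_available, is_torch_available

# Mirrors the import dispatch above: PyTorch now wins when both are installed.
if is_torch_available():
    print("doctr resolves to the PyTorch implementations")
elif is_tf_available():
    print("doctr resolves to the TensorFlow implementations")
else:
    print("no supported deep learning backend found")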
doctr/models/classification/predictor/pytorch.py

@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Union
 
 import numpy as np
 import torch
@@ -20,35 +19,38 @@ class OrientationPredictor(nn.Module):
     4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core classification architecture (backbone + classification head)
     """
 
     def __init__(
         self,
-        pre_processor: PreProcessor,
-        model: nn.Module,
+        pre_processor: PreProcessor | None,
+        model: nn.Module | None,
    ) -> None:
         super().__init__()
-        self.pre_processor = pre_processor
-        self.model = model.eval()
+        self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
+        self.model = model.eval() if isinstance(model, nn.Module) else None
 
     @torch.inference_mode()
     def forward(
         self,
-        inputs: List[Union[np.ndarray, torch.Tensor]],
-    ) -> List[Union[List[int], List[float]]]:
+        inputs: list[np.ndarray | torch.Tensor],
+    ) -> list[list[int] | list[float]]:
         # Dimension check
         if any(input.ndim != 3 for input in inputs):
             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
 
+        if self.model is None or self.pre_processor is None:
+            # predictor is disabled
+            return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
+
         processed_batches = self.pre_processor(inputs)
         _params = next(self.model.parameters())
         self.model, processed_batches = set_device_and_dtype(
             self.model, processed_batches, _params.device, _params.dtype
         )
-        predicted_batches = [self.model(batch) for batch in processed_batches]
+        predicted_batches = [self.model(batch) for batch in processed_batches]  # type: ignore[misc]
         # confidence
         probs = [
             torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
@@ -57,7 +59,7 @@ class OrientationPredictor(nn.Module):
         predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]
 
         class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
-        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
+        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]  # type: ignore
         confs = [round(float(p), 2) for prob in probs for p in prob]
 
         return [class_idxs, classes, confs]
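
OrientationPredictor now tolerates None for either component: the dimension check still runs, after which a predictor missing its model or pre-processor short-circuits and reports angle 0 with confidence 1.0 for every input. A minimal sketch of that disabled path (PyTorch backend; the import path assumes the package re-export shown in the __init__.py diff above):

import numpy as np

from doctr.models.classification.predictor import OrientationPredictor

# Both components disabled: no inference is run at all.
predictor = OrientationPredictor(pre_processor=None, model=None)

pages = [np.zeros((512, 512, 3), dtype=np.uint8) for _ in range(2)]
class_idxs, angles, confs = predictor(pages)
print(class_idxs, angles, confs)  # [0, 0] [0, 0] [1.0, 1.0]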
doctr/models/classification/predictor/tensorflow.py

@@ -1,13 +1,12 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Union
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
+from tensorflow.keras import Model
 
 from doctr.models.preprocessor import PreProcessor
 from doctr.utils.repr import NestedObject
@@ -20,29 +19,32 @@ class OrientationPredictor(NestedObject):
     4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core classification architecture (backbone + classification head)
     """
 
-    _children_names: List[str] = ["pre_processor", "model"]
+    _children_names: list[str] = ["pre_processor", "model"]
 
     def __init__(
         self,
-        pre_processor: PreProcessor,
-        model: keras.Model,
+        pre_processor: PreProcessor | None,
+        model: Model | None,
     ) -> None:
-        self.pre_processor = pre_processor
-        self.model = model
+        self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
+        self.model = model if isinstance(model, Model) else None
 
     def __call__(
         self,
-        inputs: List[Union[np.ndarray, tf.Tensor]],
-    ) -> List[Union[List[int], List[float]]]:
+        inputs: list[np.ndarray | tf.Tensor],
+    ) -> list[list[int] | list[float]]:
         # Dimension check
         if any(input.ndim != 3 for input in inputs):
             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
 
+        if self.model is None or self.pre_processor is None:
+            # predictor is disabled
+            return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
+
         processed_batches = self.pre_processor(inputs)
         predicted_batches = [self.model(batch, training=False) for batch in processed_batches]
 
doctr/models/classification/resnet/__init__.py

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/classification/resnet/pytorch.py

@@ -1,11 +1,12 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any
 
 from torch import nn
 from torchvision.models.resnet import BasicBlock
@@ -21,7 +22,7 @@ from ...utils import conv_sequence_pt, load_pretrained_params
 __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide", "resnet_stage"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "resnet18": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -60,9 +61,9 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 }
 
 
-def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> List[nn.Module]:
+def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> list[nn.Module]:
     """Build a ResNet stage"""
-    _layers: List[nn.Module] = []
+    _layers: list[nn.Module] = []
 
     in_chan = in_channels
     s = stride
@@ -84,7 +85,6 @@ class ResNet(nn.Sequential):
     Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
-    ----
         num_blocks: number of resnet block in each stage
         output_channels: number of channels in each stage
         stage_conv: whether to add a conv_sequence after each stage
@@ -98,19 +98,19 @@ class ResNet(nn.Sequential):
 
     def __init__(
         self,
-        num_blocks: List[int],
-        output_channels: List[int],
-        stage_stride: List[int],
-        stage_conv: List[bool],
-        stage_pooling: List[Optional[Tuple[int, int]]],
+        num_blocks: list[int],
+        output_channels: list[int],
+        stage_stride: list[int],
+        stage_conv: list[bool],
+        stage_pooling: list[tuple[int, int] | None],
         origin_stem: bool = True,
         stem_channels: int = 64,
-        attn_module: Optional[Callable[[int], nn.Module]] = None,
+        attn_module: Callable[[int], nn.Module] | None = None,
         include_top: bool = True,
         num_classes: int = 1000,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
-        _layers: List[nn.Module]
+        _layers: list[nn.Module]
         if origin_stem:
             _layers = [
                 *conv_sequence_pt(3, stem_channels, True, True, kernel_size=7, padding=3, stride=2),
@@ -156,12 +156,12 @@ class ResNet(nn.Sequential):
 def _resnet(
     arch: str,
     pretrained: bool,
-    num_blocks: List[int],
-    output_channels: List[int],
-    stage_stride: List[int],
-    stage_conv: List[bool],
-    stage_pooling: List[Optional[Tuple[int, int]]],
-    ignore_keys: Optional[List[str]] = None,
+    num_blocks: list[int],
+    output_channels: list[int],
+    stage_stride: list[int],
+    stage_conv: list[bool],
+    stage_pooling: list[tuple[int, int] | None],
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> ResNet:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -188,7 +188,7 @@ def _tv_resnet(
     arch: str,
     pretrained: bool,
     arch_fn,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> TVResNet:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -224,12 +224,10 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> TVResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet18 model
     """
     return _tv_resnet(
@@ -253,12 +251,10 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet31 model
     """
     return _resnet(
@@ -287,12 +283,10 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> TVResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet34 model
     """
     return _tv_resnet(
@@ -315,12 +309,10 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet34_wide model
     """
     return _resnet(
@@ -349,12 +341,10 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> TVResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet50 model
     """
     return _tv_resnet(
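
One last pattern worth flagging in this file: Callable is now imported from collections.abc instead of typing, in line with the deprecation of the typing aliases since Python 3.9. The annotation it serves, attn_module: Callable[[int], nn.Module] | None, is a factory receiving a channel count; a small illustrative wrapper (not doctr code) of how such a hook is typically consumed:

from collections.abc import Callable  # preferred over the deprecated typing.Callable

from torch import nn

def build_attention(
    attn_module: Callable[[int], nn.Module] | None,
    channels: int,
) -> nn.Module:
    # Fall back to a no-op module when no attention factory is supplied.
    return attn_module(channels) if attn_module is not None else nn.Identity()

# Example factory: a squeeze-and-excitation-style block built for 64 channels.
attn = build_attention(
    lambda c: nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(c, c, 1), nn.Sigmoid()),
    64,
)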