python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/classification/magc_resnet/pytorch.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -7,20 +7,19 @@
 import math
 from copy import deepcopy
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import torch
 from torch import nn
 
 from doctr.datasets import VOCABS
 
-from ...utils.pytorch import load_pretrained_params
 from ..resnet.pytorch import ResNet
 
 __all__ = ["magc_resnet31"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "magc_resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -36,7 +35,6 @@ class MAGC(nn.Module):
     <https://arxiv.org/pdf/1910.02562.pdf>`_.
 
     Args:
-    ----
         inplanes: input channels
         headers: number of headers to split channels
         attn_scale: if True, re-scale attention to counteract the variance distibutions
@@ -50,7 +48,7 @@ class MAGC(nn.Module):
         headers: int = 8,
         attn_scale: bool = False,
         ratio: float = 0.0625,  # bottleneck ratio of 1/16 as described in paper
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
 
@@ -105,12 +103,12 @@
 def _magc_resnet(
     arch: str,
     pretrained: bool,
-    num_blocks: List[int],
-    output_channels: List[int],
-    stage_stride: List[int],
-    stage_conv: List[bool],
-    stage_pooling: List[Optional[Tuple[int, int]]],
-    ignore_keys: Optional[List[str]] = None,
+    num_blocks: list[int],
+    output_channels: list[int],
+    stage_stride: list[int],
+    stage_conv: list[bool],
+    stage_pooling: list[tuple[int, int] | None],
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> ResNet:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -137,7 +135,7 @@ def _magc_resnet(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -154,12 +152,10 @@ def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A feature extractor model
     """
     return _magc_resnet(
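
Beyond typing, the hunks above show the second recurring change of this release: the free function `load_pretrained_params(model, url, ...)` becomes a `from_pretrained` method on the model instance. Judging from the MobileNetV3 hunks further down, which spell the method out, it is a thin wrapper; a sketch of the presumed equivalence (not doctr's literal source):

from typing import Any

from doctr.models.utils import load_pretrained_params


def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
    # 0.10: load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
    # 0.12: model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
    load_pretrained_params(self, path_or_url, **kwargs)

The net effect is a public entry point for loading arbitrary checkpoints (local path or URL) onto an already-built model.
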
doctr/models/classification/magc_resnet/tensorflow.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 import math
 from copy import deepcopy
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import activations, layers
@@ -14,13 +14,13 @@ from tensorflow.keras.models import Sequential
 
 from doctr.datasets import VOCABS
 
-from ...utils import _build_model, load_pretrained_params
+from ...utils import _build_model
 from ..resnet.tensorflow import ResNet
 
 __all__ = ["magc_resnet31"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "magc_resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -36,7 +36,6 @@ class MAGC(layers.Layer):
     <https://arxiv.org/pdf/1910.02562.pdf>`_.
 
     Args:
-    ----
         inplanes: input channels
         headers: number of headers to split channels
         attn_scale: if True, re-scale attention to counteract the variance distibutions
@@ -122,11 +121,11 @@ class MAGC(layers.Layer):
 def _magc_resnet(
     arch: str,
     pretrained: bool,
-    num_blocks: List[int],
-    output_channels: List[int],
-    stage_downsample: List[bool],
-    stage_conv: List[bool],
-    stage_pooling: List[Optional[Tuple[int, int]]],
+    num_blocks: list[int],
+    output_channels: list[int],
+    stage_downsample: list[bool],
+    stage_conv: list[bool],
+    stage_pooling: list[tuple[int, int] | None],
     origin_stem: bool = True,
     **kwargs: Any,
 ) -> ResNet:
@@ -158,8 +157,8 @@ def _magc_resnet(
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        model.from_pretrained(
+            default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
         )
 
     return model
@@ -177,12 +176,10 @@ def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A feature extractor model
     """
     return _magc_resnet(
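
Note the backend asymmetry in how a resized classification head is handled: the PyTorch loader drops the offending weights via `ignore_keys`, while the Keras loader passes `skip_mismatch` so that layers whose shapes no longer match are skipped instead of raising. Either way, the builders take care of it; a hedged fine-tuning sketch under the 0.12 API shown above:

from doctr.models.classification import magc_resnet31

# num_classes differs from the pretrained head, so the mismatching layer is
# skipped (TF) or its weights ignored (PyTorch) when the checkpoint is loaded
model = magc_resnet31(pretrained=True, num_classes=10)
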
doctr/models/classification/mobilenet/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
+if is_torch_available():
     from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *
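
This three-line hunk flips the backend priority: 0.10 imported the TensorFlow implementation when both frameworks were installed, 0.12 prefers PyTorch. The same flip appears in every backend-gated `__init__.py` in this diff. If you rely on the TF classes in a dual-framework environment, pin the backend explicitly; doctr's `file_utils` honors the `USE_TF`/`USE_TORCH` environment variables, set before the first doctr import:

import os

os.environ["USE_TF"] = "1"  # force the TensorFlow backend despite the new default

from doctr.models.classification import mobilenet_v3_small  # now resolves to the TF implementation
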
doctr/models/classification/mobilenet/pytorch.py
@@ -1,12 +1,13 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
 
+import types
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 from torchvision.models import mobilenetv3
 from torchvision.models.mobilenetv3 import MobileNetV3
@@ -25,7 +26,7 @@ __all__ = [
     "mobilenet_v3_small_page_orientation",
 ]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "mobilenet_v3_large": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -74,8 +75,8 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 def _mobilenet_v3(
     arch: str,
     pretrained: bool,
-    rect_strides: Optional[List[str]] = None,
-    ignore_keys: Optional[List[str]] = None,
+    rect_strides: list[str] | None = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> mobilenetv3.MobileNetV3:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -99,12 +100,25 @@ def _mobilenet_v3(
                 m = getattr(m, child)
             m.stride = (2, 1)
 
+    # monkeypatch the model to allow for loading pretrained parameters
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:  # noqa: D417
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+    # Bind method to the instance
+    model.from_pretrained = types.MethodType(from_pretrained, model)
+
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     model.cfg = _cfg
 
@@ -123,12 +137,10 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -148,12 +160,10 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -177,12 +187,10 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -205,12 +213,10 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -234,12 +240,10 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -262,12 +266,10 @@ def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
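
Because these builders return torchvision's `MobileNetV3` rather than a doctr subclass, there is no class on which to define `from_pretrained`, hence the instance-level monkeypatch in the hunk above. A self-contained sketch of the `types.MethodType` pattern (all names here are illustrative):

import types


class ThirdPartyModel:
    """Stand-in for a class we don't control, e.g. torchvision's MobileNetV3."""


def from_pretrained(self, path_or_url: str) -> None:
    # in doctr this delegates to load_pretrained_params(self, path_or_url, ...)
    print(f"loading weights from {path_or_url} into {type(self).__name__}")


model = ThirdPartyModel()
# MethodType binds the function to this single instance, so `self` is supplied
# automatically; the class itself and other instances are left untouched
model.from_pretrained = types.MethodType(from_pretrained, model)
model.from_pretrained("weights.pt")
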
doctr/models/classification/mobilenet/tensorflow.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import layers
@@ -26,7 +26,7 @@ __all__ = [
 ]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "mobilenet_v3_large": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -76,7 +76,7 @@ def hard_swish(x: tf.Tensor) -> tf.Tensor:
     return x * tf.nn.relu6(x + 3.0) / 6.0
 
 
-def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+def _make_divisible(v: float, divisor: int, min_value: int | None = None) -> int:
     if min_value is None:
         min_value = divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -112,7 +112,7 @@ class InvertedResidualConfig:
         out_channels: int,
         use_se: bool,
         activation: str,
-        stride: Union[int, Tuple[int, int]],
+        stride: int | tuple[int, int],
         width_mult: float = 1,
     ) -> None:
         self.input_channels = self.adjust_channels(input_channels, width_mult)
@@ -132,7 +132,6 @@ class InvertedResidual(layers.Layer):
     """InvertedResidual for mobilenet
 
     Args:
-    ----
         conf: configuration object for inverted residual
     """
 
@@ -201,12 +200,12 @@ class MobileNetV3(Sequential):
 
     def __init__(
         self,
-        layout: List[InvertedResidualConfig],
+        layout: list[InvertedResidualConfig],
         include_top: bool = True,
         head_chans: int = 1024,
         num_classes: int = 1000,
-        cfg: Optional[Dict[str, Any]] = None,
-        input_shape: Optional[Tuple[int, int, int]] = None,
+        cfg: dict[str, Any] | None = None,
+        input_shape: tuple[int, int, int] | None = None,
     ) -> None:
         _layers = [
             Sequential(
@@ -237,6 +236,15 @@ class MobileNetV3(Sequential):
         super().__init__(_layers)
         self.cfg = cfg
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
 
 def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwargs: Any) -> MobileNetV3:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -301,8 +309,8 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        model.from_pretrained(
+            default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
         )
 
     return model
@@ -320,12 +328,10 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small", pretrained, False, **kwargs)
@@ -343,12 +349,10 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small_r", pretrained, True, **kwargs)
@@ -366,12 +370,10 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_large", pretrained, False, **kwargs)
@@ -389,12 +391,10 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)
@@ -412,12 +412,10 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)
@@ -435,12 +433,10 @@ def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small_page_orientation", pretrained, include_top=True, **kwargs)
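
On the Keras side, `MobileNetV3` is a doctr-owned `Sequential` subclass, so `from_pretrained` lands as a regular method rather than a monkeypatch, and resized heads are handled with `skip_mismatch` as in the magc module. A hedged usage sketch (the checkpoint path is a placeholder, not a published doctr URL):

from doctr.models.classification import mobilenet_v3_small

model = mobilenet_v3_small(pretrained=False, num_classes=4)
# wraps load_pretrained_params; accepts a local file path or a URL
model.from_pretrained("/path/to/checkpoint")
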
doctr/models/classification/predictor/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/classification/predictor/pytorch.py
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -20,15 +19,14 @@ class OrientationPredictor(nn.Module):
     4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core classification architecture (backbone + classification head)
     """
 
     def __init__(
         self,
-        pre_processor: Optional[PreProcessor],
-        model: Optional[nn.Module],
+        pre_processor: PreProcessor | None,
+        model: nn.Module | None,
     ) -> None:
         super().__init__()
         self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
@@ -37,8 +35,8 @@ class OrientationPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        inputs: List[Union[np.ndarray, torch.Tensor]],
-    ) -> List[Union[List[int], List[float]]]:
+        inputs: list[np.ndarray | torch.Tensor],
+    ) -> list[list[int] | list[float]]:
         # Dimension check
         if any(input.ndim != 3 for input in inputs):
             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
@@ -52,7 +50,7 @@ class OrientationPredictor(nn.Module):
         self.model, processed_batches = set_device_and_dtype(
             self.model, processed_batches, _params.device, _params.dtype
         )
-        predicted_batches = [self.model(batch) for batch in processed_batches]  # type: ignore[misc]
+        predicted_batches = [self.model(batch) for batch in processed_batches]
         # confidence
         probs = [
             torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
@@ -61,7 +59,7 @@ class OrientationPredictor(nn.Module):
         predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]
 
         class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
-        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]  # type: ignore[union-attr]
+        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]  # type: ignore
        confs = [round(float(p), 2) for prob in probs for p in prob]
 
         return [class_idxs, classes, confs]
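
The return value of `forward` is three parallel lists: raw class indices, the corresponding angles looked up in `model.cfg["classes"]`, and per-sample confidences rounded to two decimals. A hedged usage sketch, assuming the `page_orientation_predictor` helper from `doctr/models/classification/zoo.py` (also touched in this release):

import numpy as np

from doctr.models import page_orientation_predictor

predictor = page_orientation_predictor(pretrained=True)
class_idxs, angles, confs = predictor([np.zeros((512, 512, 3), dtype=np.uint8)])
# e.g. class_idxs == [0], angles == [0], confs == [0.99] for an upright page
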
doctr/models/classification/predictor/tensorflow.py
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Optional, Union
 
 import numpy as np
 import tensorflow as tf
@@ -20,25 +19,24 @@ class OrientationPredictor(NestedObject):
     4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core classification architecture (backbone + classification head)
     """
 
-    _children_names: List[str] = ["pre_processor", "model"]
+    _children_names: list[str] = ["pre_processor", "model"]
 
     def __init__(
         self,
-        pre_processor: Optional[PreProcessor],
-        model: Optional[Model],
+        pre_processor: PreProcessor | None,
+        model: Model | None,
     ) -> None:
         self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
         self.model = model if isinstance(model, Model) else None
 
     def __call__(
         self,
-        inputs: List[Union[np.ndarray, tf.Tensor]],
-    ) -> List[Union[List[int], List[float]]]:
+        inputs: list[np.ndarray | tf.Tensor],
+    ) -> list[list[int] | list[float]]:
         # Dimension check
         if any(input.ndim != 3 for input in inputs):
             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
doctr/models/classification/resnet/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]