python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -1,433 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
7
-
8
- from copy import deepcopy
9
- from typing import Any
10
-
11
- import tensorflow as tf
12
- from tensorflow.keras import layers
13
- from tensorflow.keras.models import Sequential
14
-
15
- from ....datasets import VOCABS
16
- from ...utils import _build_model, conv_sequence, load_pretrained_params
17
-
18
- __all__ = [
19
- "MobileNetV3",
20
- "mobilenet_v3_small",
21
- "mobilenet_v3_small_r",
22
- "mobilenet_v3_large",
23
- "mobilenet_v3_large_r",
24
- "mobilenet_v3_small_crop_orientation",
25
- "mobilenet_v3_small_page_orientation",
26
- ]
27
-
28
-
29
- default_cfgs: dict[str, dict[str, Any]] = {
30
- "mobilenet_v3_large": {
31
- "mean": (0.694, 0.695, 0.693),
32
- "std": (0.299, 0.296, 0.301),
33
- "input_shape": (32, 32, 3),
34
- "classes": list(VOCABS["french"]),
35
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
36
- },
37
- "mobilenet_v3_large_r": {
38
- "mean": (0.694, 0.695, 0.693),
39
- "std": (0.299, 0.296, 0.301),
40
- "input_shape": (32, 32, 3),
41
- "classes": list(VOCABS["french"]),
42
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
43
- },
44
- "mobilenet_v3_small": {
45
- "mean": (0.694, 0.695, 0.693),
46
- "std": (0.299, 0.296, 0.301),
47
- "input_shape": (32, 32, 3),
48
- "classes": list(VOCABS["french"]),
49
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
50
- },
51
- "mobilenet_v3_small_r": {
52
- "mean": (0.694, 0.695, 0.693),
53
- "std": (0.299, 0.296, 0.301),
54
- "input_shape": (32, 32, 3),
55
- "classes": list(VOCABS["french"]),
56
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
57
- },
58
- "mobilenet_v3_small_crop_orientation": {
59
- "mean": (0.694, 0.695, 0.693),
60
- "std": (0.299, 0.296, 0.301),
61
- "input_shape": (128, 128, 3),
62
- "classes": [0, -90, 180, 90],
63
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
64
- },
65
- "mobilenet_v3_small_page_orientation": {
66
- "mean": (0.694, 0.695, 0.693),
67
- "std": (0.299, 0.296, 0.301),
68
- "input_shape": (512, 512, 3),
69
- "classes": [0, -90, 180, 90],
70
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
71
- },
72
- }
73
-
74
-
75
- def hard_swish(x: tf.Tensor) -> tf.Tensor:
76
- return x * tf.nn.relu6(x + 3.0) / 6.0
77
-
78
-
79
- def _make_divisible(v: float, divisor: int, min_value: int | None = None) -> int:
80
- if min_value is None:
81
- min_value = divisor
82
- new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
83
- # Make sure that round down does not go down by more than 10%.
84
- if new_v < 0.9 * v:
85
- new_v += divisor
86
- return new_v
87
-
88
-
89
- class SqueezeExcitation(Sequential):
90
- """Squeeze and Excitation."""
91
-
92
- def __init__(self, chan: int, squeeze_factor: int = 4) -> None:
93
- super().__init__([
94
- layers.GlobalAveragePooling2D(),
95
- layers.Dense(chan // squeeze_factor, activation="relu"),
96
- layers.Dense(chan, activation="hard_sigmoid"),
97
- layers.Reshape((1, 1, chan)),
98
- ])
99
-
100
- def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
101
- x = super().call(inputs, **kwargs)
102
- x = tf.math.multiply(inputs, x)
103
- return x
104
-
105
-
106
- class InvertedResidualConfig:
107
- def __init__(
108
- self,
109
- input_channels: int,
110
- kernel: int,
111
- expanded_channels: int,
112
- out_channels: int,
113
- use_se: bool,
114
- activation: str,
115
- stride: int | tuple[int, int],
116
- width_mult: float = 1,
117
- ) -> None:
118
- self.input_channels = self.adjust_channels(input_channels, width_mult)
119
- self.kernel = kernel
120
- self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
121
- self.out_channels = self.adjust_channels(out_channels, width_mult)
122
- self.use_se = use_se
123
- self.use_hs = activation == "HS"
124
- self.stride = stride
125
-
126
- @staticmethod
127
- def adjust_channels(channels: int, width_mult: float):
128
- return _make_divisible(channels * width_mult, 8)
129
-
130
-
131
- class InvertedResidual(layers.Layer):
132
- """InvertedResidual for mobilenet
133
-
134
- Args:
135
- conf: configuration object for inverted residual
136
- """
137
-
138
- def __init__(
139
- self,
140
- conf: InvertedResidualConfig,
141
- **kwargs: Any,
142
- ) -> None:
143
- _kwargs = {"input_shape": kwargs.pop("input_shape")} if isinstance(kwargs.get("input_shape"), tuple) else {}
144
- super().__init__(**kwargs)
145
-
146
- act_fn = hard_swish if conf.use_hs else tf.nn.relu
147
-
148
- _is_s1 = (isinstance(conf.stride, tuple) and conf.stride == (1, 1)) or conf.stride == 1
149
- self.use_res_connect = _is_s1 and conf.input_channels == conf.out_channels
150
-
151
- _layers = []
152
- # expand
153
- if conf.expanded_channels != conf.input_channels:
154
- _layers.extend(conv_sequence(conf.expanded_channels, act_fn, kernel_size=1, bn=True, **_kwargs))
155
-
156
- # depth-wise
157
- _layers.extend(
158
- conv_sequence(
159
- conf.expanded_channels,
160
- act_fn,
161
- kernel_size=conf.kernel,
162
- strides=conf.stride,
163
- bn=True,
164
- groups=conf.expanded_channels,
165
- )
166
- )
167
-
168
- if conf.use_se:
169
- _layers.append(SqueezeExcitation(conf.expanded_channels))
170
-
171
- # project
172
- _layers.extend(
173
- conv_sequence(
174
- conf.out_channels,
175
- None,
176
- kernel_size=1,
177
- bn=True,
178
- )
179
- )
180
-
181
- self.block = Sequential(_layers)
182
-
183
- def call(
184
- self,
185
- inputs: tf.Tensor,
186
- **kwargs: Any,
187
- ) -> tf.Tensor:
188
- out = self.block(inputs, **kwargs)
189
- if self.use_res_connect:
190
- out = tf.add(out, inputs)
191
-
192
- return out
193
-
194
-
195
- class MobileNetV3(Sequential):
196
- """Implements MobileNetV3, inspired from both:
197
- <https://github.com/xiaochus/MobileNetV3/tree/master/model>`_.
198
- and <https://pytorch.org/vision/stable/_modules/torchvision/models/mobilenetv3.html>`_.
199
- """
200
-
201
- def __init__(
202
- self,
203
- layout: list[InvertedResidualConfig],
204
- include_top: bool = True,
205
- head_chans: int = 1024,
206
- num_classes: int = 1000,
207
- cfg: dict[str, Any] | None = None,
208
- input_shape: tuple[int, int, int] | None = None,
209
- ) -> None:
210
- _layers = [
211
- Sequential(
212
- conv_sequence(
213
- layout[0].input_channels, hard_swish, True, kernel_size=3, strides=2, input_shape=input_shape
214
- ),
215
- name="stem",
216
- )
217
- ]
218
-
219
- for idx, conf in enumerate(layout):
220
- _layers.append(
221
- InvertedResidual(conf, name=f"inverted_{idx}"),
222
- )
223
-
224
- _layers.append(
225
- Sequential(conv_sequence(6 * layout[-1].out_channels, hard_swish, True, kernel_size=1), name="final_block")
226
- )
227
-
228
- if include_top:
229
- _layers.extend([
230
- layers.GlobalAveragePooling2D(),
231
- layers.Dense(head_chans, activation=hard_swish),
232
- layers.Dropout(0.2),
233
- layers.Dense(num_classes),
234
- ])
235
-
236
- super().__init__(_layers)
237
- self.cfg = cfg
238
-
239
-
240
- def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwargs: Any) -> MobileNetV3:
241
- kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
242
- kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
243
- kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
244
-
245
- _cfg = deepcopy(default_cfgs[arch])
246
- _cfg["num_classes"] = kwargs["num_classes"]
247
- _cfg["classes"] = kwargs["classes"]
248
- _cfg["input_shape"] = kwargs["input_shape"]
249
- kwargs.pop("classes")
250
-
251
- # cf. Table 1 & 2 of the paper
252
- if arch.startswith("mobilenet_v3_small"):
253
- inverted_residual_setting = [
254
- InvertedResidualConfig(16, 3, 16, 16, True, "RE", 2), # C1
255
- InvertedResidualConfig(16, 3, 72, 24, False, "RE", (2, 1) if rect_strides else 2), # C2
256
- InvertedResidualConfig(24, 3, 88, 24, False, "RE", 1),
257
- InvertedResidualConfig(24, 5, 96, 40, True, "HS", (2, 1) if rect_strides else 2), # C3
258
- InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1),
259
- InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1),
260
- InvertedResidualConfig(40, 5, 120, 48, True, "HS", 1),
261
- InvertedResidualConfig(48, 5, 144, 48, True, "HS", 1),
262
- InvertedResidualConfig(48, 5, 288, 96, True, "HS", (2, 1) if rect_strides else 2), # C4
263
- InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1),
264
- InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1),
265
- ]
266
- head_chans = 1024
267
- else:
268
- inverted_residual_setting = [
269
- InvertedResidualConfig(16, 3, 16, 16, False, "RE", 1),
270
- InvertedResidualConfig(16, 3, 64, 24, False, "RE", 2), # C1
271
- InvertedResidualConfig(24, 3, 72, 24, False, "RE", 1),
272
- InvertedResidualConfig(24, 5, 72, 40, True, "RE", (2, 1) if rect_strides else 2), # C2
273
- InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1),
274
- InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1),
275
- InvertedResidualConfig(40, 3, 240, 80, False, "HS", (2, 1) if rect_strides else 2), # C3
276
- InvertedResidualConfig(80, 3, 200, 80, False, "HS", 1),
277
- InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1),
278
- InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1),
279
- InvertedResidualConfig(80, 3, 480, 112, True, "HS", 1),
280
- InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1),
281
- InvertedResidualConfig(112, 5, 672, 160, True, "HS", (2, 1) if rect_strides else 2), # C4
282
- InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1),
283
- InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1),
284
- ]
285
- head_chans = 1280
286
-
287
- kwargs["num_classes"] = _cfg["num_classes"]
288
- kwargs["input_shape"] = _cfg["input_shape"]
289
-
290
- # Build the model
291
- model = MobileNetV3(
292
- inverted_residual_setting,
293
- head_chans=head_chans,
294
- cfg=_cfg,
295
- **kwargs,
296
- )
297
- _build_model(model)
298
-
299
- # Load pretrained parameters
300
- if pretrained:
301
- # The number of classes is not the same as the number of classes in the pretrained model =>
302
- # skip the mismatching layers for fine tuning
303
- load_pretrained_params(
304
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
305
- )
306
-
307
- return model
308
-
309
-
310
- def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
311
- """MobileNetV3-Small architecture as described in
312
- `"Searching for MobileNetV3",
313
- <https://arxiv.org/pdf/1905.02244.pdf>`_.
314
-
315
- >>> import tensorflow as tf
316
- >>> from doctr.models import mobilenet_v3_small
317
- >>> model = mobilenet_v3_small(pretrained=False)
318
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
319
- >>> out = model(input_tensor)
320
-
321
- Args:
322
- pretrained: boolean, True if model is pretrained
323
- **kwargs: keyword arguments of the MobileNetV3 architecture
324
-
325
- Returns:
326
- a keras.Model
327
- """
328
- return _mobilenet_v3("mobilenet_v3_small", pretrained, False, **kwargs)
329
-
330
-
331
- def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
332
- """MobileNetV3-Small architecture as described in
333
- `"Searching for MobileNetV3",
334
- <https://arxiv.org/pdf/1905.02244.pdf>`_, with rectangular pooling.
335
-
336
- >>> import tensorflow as tf
337
- >>> from doctr.models import mobilenet_v3_small_r
338
- >>> model = mobilenet_v3_small_r(pretrained=False)
339
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
340
- >>> out = model(input_tensor)
341
-
342
- Args:
343
- pretrained: boolean, True if model is pretrained
344
- **kwargs: keyword arguments of the MobileNetV3 architecture
345
-
346
- Returns:
347
- a keras.Model
348
- """
349
- return _mobilenet_v3("mobilenet_v3_small_r", pretrained, True, **kwargs)
350
-
351
-
352
- def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
353
- """MobileNetV3-Large architecture as described in
354
- `"Searching for MobileNetV3",
355
- <https://arxiv.org/pdf/1905.02244.pdf>`_.
356
-
357
- >>> import tensorflow as tf
358
- >>> from doctr.models import mobilenet_v3_large
359
- >>> model = mobilenet_v3_large(pretrained=False)
360
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
361
- >>> out = model(input_tensor)
362
-
363
- Args:
364
- pretrained: boolean, True if model is pretrained
365
- **kwargs: keyword arguments of the MobileNetV3 architecture
366
-
367
- Returns:
368
- a keras.Model
369
- """
370
- return _mobilenet_v3("mobilenet_v3_large", pretrained, False, **kwargs)
371
-
372
-
373
- def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
374
- """MobileNetV3-Large architecture as described in
375
- `"Searching for MobileNetV3",
376
- <https://arxiv.org/pdf/1905.02244.pdf>`_.
377
-
378
- >>> import tensorflow as tf
379
- >>> from doctr.models import mobilenet_v3_large_r
380
- >>> model = mobilenet_v3_large_r(pretrained=False)
381
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
382
- >>> out = model(input_tensor)
383
-
384
- Args:
385
- pretrained: boolean, True if model is pretrained
386
- **kwargs: keyword arguments of the MobileNetV3 architecture
387
-
388
- Returns:
389
- a keras.Model
390
- """
391
- return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)
392
-
393
-
394
- def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
395
- """MobileNetV3-Small architecture as described in
396
- `"Searching for MobileNetV3",
397
- <https://arxiv.org/pdf/1905.02244.pdf>`_.
398
-
399
- >>> import tensorflow as tf
400
- >>> from doctr.models import mobilenet_v3_small_crop_orientation
401
- >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
402
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
403
- >>> out = model(input_tensor)
404
-
405
- Args:
406
- pretrained: boolean, True if model is pretrained
407
- **kwargs: keyword arguments of the MobileNetV3 architecture
408
-
409
- Returns:
410
- a keras.Model
411
- """
412
- return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)
413
-
414
-
415
- def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
416
- """MobileNetV3-Small architecture as described in
417
- `"Searching for MobileNetV3",
418
- <https://arxiv.org/pdf/1905.02244.pdf>`_.
419
-
420
- >>> import tensorflow as tf
421
- >>> from doctr.models import mobilenet_v3_small_page_orientation
422
- >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
423
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
424
- >>> out = model(input_tensor)
425
-
426
- Args:
427
- pretrained: boolean, True if model is pretrained
428
- **kwargs: keyword arguments of the MobileNetV3 architecture
429
-
430
- Returns:
431
- a keras.Model
432
- """
433
- return _mobilenet_v3("mobilenet_v3_small_page_orientation", pretrained, include_top=True, **kwargs)
@@ -1,60 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
-
7
- import numpy as np
8
- import tensorflow as tf
9
- from tensorflow.keras import Model
10
-
11
- from doctr.models.preprocessor import PreProcessor
12
- from doctr.utils.repr import NestedObject
13
-
14
- __all__ = ["OrientationPredictor"]
15
-
16
-
17
- class OrientationPredictor(NestedObject):
18
- """Implements an object able to detect the reading direction of a text box or a page.
19
- 4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
20
-
21
- Args:
22
- pre_processor: transform inputs for easier batched model inference
23
- model: core classification architecture (backbone + classification head)
24
- """
25
-
26
- _children_names: list[str] = ["pre_processor", "model"]
27
-
28
- def __init__(
29
- self,
30
- pre_processor: PreProcessor | None,
31
- model: Model | None,
32
- ) -> None:
33
- self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
34
- self.model = model if isinstance(model, Model) else None
35
-
36
- def __call__(
37
- self,
38
- inputs: list[np.ndarray | tf.Tensor],
39
- ) -> list[list[int] | list[float]]:
40
- # Dimension check
41
- if any(input.ndim != 3 for input in inputs):
42
- raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
43
-
44
- if self.model is None or self.pre_processor is None:
45
- # predictor is disabled
46
- return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
47
-
48
- processed_batches = self.pre_processor(inputs)
49
- predicted_batches = [self.model(batch, training=False) for batch in processed_batches]
50
-
51
- # confidence
52
- probs = [tf.math.reduce_max(tf.nn.softmax(batch, axis=1), axis=1).numpy() for batch in predicted_batches]
53
- # Postprocess predictions
54
- predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches]
55
-
56
- class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
57
- classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
58
- confs = [round(float(p), 2) for prob in probs for p in prob]
59
-
60
- return [class_idxs, classes, confs]