python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +8 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +7 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +4 -5
- doctr/datasets/ic13.py +4 -5
- doctr/datasets/iiit5k.py +6 -5
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +6 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +6 -5
- doctr/datasets/svt.py +4 -5
- doctr/datasets/synthtext.py +4 -5
- doctr/datasets/utils.py +34 -29
- doctr/datasets/vocabs.py +17 -7
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +2 -7
- doctr/io/elements.py +59 -79
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +30 -48
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +8 -11
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +5 -17
- doctr/models/classification/mobilenet/tensorflow.py +8 -21
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +6 -8
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +20 -31
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +8 -15
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +9 -12
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +6 -12
- doctr/models/classification/zoo.py +19 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +14 -26
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +14 -23
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +3 -7
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +18 -19
- doctr/models/kie_predictor/tensorflow.py +13 -14
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +12 -13
- doctr/models/predictor/tensorflow.py +8 -9
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +11 -23
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +12 -22
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +16 -22
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +12 -21
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +12 -20
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +14 -17
- doctr/models/utils/tensorflow.py +17 -16
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +16 -47
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/classification/mobilenet/pytorch.py

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 from torchvision.models import mobilenetv3
 from torchvision.models.mobilenetv3 import MobileNetV3
@@ -25,7 +25,7 @@ __all__ = [
     "mobilenet_v3_small_page_orientation",
 ]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "mobilenet_v3_large": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -74,8 +74,8 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 def _mobilenet_v3(
     arch: str,
     pretrained: bool,
-    rect_strides: Optional[List[str]] = None,
-    ignore_keys: Optional[List[str]] = None,
+    rect_strides: list[str] | None = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> mobilenetv3.MobileNetV3:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -123,12 +123,10 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -148,12 +146,10 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -177,12 +173,10 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -205,12 +199,10 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -234,12 +226,10 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
@@ -262,12 +252,10 @@ def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a torch.nn.Module
     """
     return _mobilenet_v3(
```
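The dominant change across these hunks is annotation modernization: `typing.Dict`, `List`, and `Optional` give way to builtin generics (PEP 585) and `|` unions (PEP 604), which parse natively only on Python 3.10+, suggesting 0.11.0 targets Python 3.10 or newer. A minimal side-by-side sketch of the two styles (illustrative code, not from the package; the function name and values are made up):

```python
from typing import Any

# 0.10.0 style (needs `from typing import Dict, List, Optional`):
#     def _mobilenet_v3(arch: str, cfgs: Dict[str, Dict[str, Any]],
#                       rect_strides: Optional[List[str]] = None) -> ...
# 0.11.0 style, as in the hunks above: builtins subscripted directly,
# `X | None` instead of Optional[X], no extra typing imports needed.
def signature_demo(
    arch: str,
    cfgs: dict[str, dict[str, Any]],
    rect_strides: list[str] | None = None,
) -> str:
    # Purely illustrative body: report which strides would be rectified.
    strides = "default strides" if rect_strides is None else ", ".join(rect_strides)
    return f"{arch} ({len(cfgs)} configs): {strides}"

print(signature_demo("mobilenet_v3_small", {}, ["layer.1", "layer.2"]))
```

The same substitution repeats in every file below, which is why nearly every module in the file list above shows a small negative line delta.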
doctr/models/classification/mobilenet/tensorflow.py

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import layers
@@ -26,7 +26,7 @@ __all__ = [
 ]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "mobilenet_v3_large": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -76,7 +76,7 @@ def hard_swish(x: tf.Tensor) -> tf.Tensor:
     return x * tf.nn.relu6(x + 3.0) / 6.0
 
 
-def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+def _make_divisible(v: float, divisor: int, min_value: int | None = None) -> int:
     if min_value is None:
         min_value = divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -112,7 +112,7 @@ class InvertedResidualConfig:
         out_channels: int,
         use_se: bool,
         activation: str,
-        stride: Union[int, Tuple[int, int]],
+        stride: int | tuple[int, int],
         width_mult: float = 1,
     ) -> None:
         self.input_channels = self.adjust_channels(input_channels, width_mult)
@@ -132,7 +132,6 @@ class InvertedResidual(layers.Layer):
     """InvertedResidual for mobilenet
 
     Args:
-    ----
         conf: configuration object for inverted residual
     """
 
@@ -201,12 +200,12 @@ class MobileNetV3(Sequential):
 
     def __init__(
         self,
-        layout: List[InvertedResidualConfig],
+        layout: list[InvertedResidualConfig],
         include_top: bool = True,
         head_chans: int = 1024,
         num_classes: int = 1000,
-        cfg: Optional[Dict[str, Any]] = None,
-        input_shape: Optional[Tuple[int, int, int]] = None,
+        cfg: dict[str, Any] | None = None,
+        input_shape: tuple[int, int, int] | None = None,
     ) -> None:
         _layers = [
             Sequential(
@@ -320,12 +319,10 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small", pretrained, False, **kwargs)
@@ -343,12 +340,10 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small_r", pretrained, True, **kwargs)
@@ -366,12 +361,10 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_large", pretrained, False, **kwargs)
@@ -389,12 +382,10 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)
@@ -412,12 +403,10 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)
@@ -435,12 +424,10 @@ def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any)
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the MobileNetV3 architecture
 
     Returns:
-    -------
         a keras.Model
     """
     return _mobilenet_v3("mobilenet_v3_small_page_orientation", pretrained, include_top=True, **kwargs)
```
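One helper worth unpacking from the TensorFlow hunks: `_make_divisible` rounds a scaled channel count to the nearest multiple of `divisor`. Only the first three lines of its body appear in the diff; the sketch below completes it with the standard 10% guard from the torchvision helper this file credits, which is an assumption on my part rather than something visible in the hunk:

```python
# Minimal sketch of the channel-rounding helper shown above. The visible
# lines compute the nearest multiple of `divisor`; the final guard (assumed
# from the upstream torchvision helper) bumps the result up one step if
# rounding lost more than 10% of the requested value.
def make_divisible(v: float, divisor: int, min_value: int | None = None) -> int:
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # assumption: the standard 10% guard, not shown in the diff
        new_v += divisor
    return new_v

print(make_divisible(27.5, 8))  # 32: 27.5 first rounds down to 24, which trips the guard
```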
doctr/models/classification/predictor/__init__.py

```diff
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
```
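These four lines decide which backend's `OrientationPredictor` a plain import from this package resolves to: 0.10.0 checked TensorFlow first, while 0.11.0 checks PyTorch first, so an environment with both frameworks installed now silently gets the PyTorch implementation. A sketch of the capability-dispatch idiom, simplified from the star-import form used here:

```python
# Sketch only: the real modules re-export the chosen backend's public
# symbols with `from .pytorch import *` / `from .tensorflow import *`.
from doctr.file_utils import is_tf_available, is_torch_available

if is_torch_available():      # 0.11.0 order: PyTorch wins when both are present
    backend = "pytorch"
elif is_tf_available():
    backend = "tensorflow"
else:
    backend = None            # doctr itself would already have failed to import

print(f"OrientationPredictor will be loaded from the {backend} module")
```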
doctr/models/classification/predictor/pytorch.py

```diff
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -20,15 +19,14 @@ class OrientationPredictor(nn.Module):
     4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core classification architecture (backbone + classification head)
     """
 
     def __init__(
         self,
-        pre_processor: Optional[PreProcessor],
-        model: Optional[nn.Module],
+        pre_processor: PreProcessor | None,
+        model: nn.Module | None,
     ) -> None:
         super().__init__()
         self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
@@ -37,8 +35,8 @@ class OrientationPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        inputs: List[Union[np.ndarray, torch.Tensor]],
-    ) -> List[Union[List[int], List[float]]]:
+        inputs: list[np.ndarray | torch.Tensor],
+    ) -> list[list[int] | list[float]]:
         # Dimension check
         if any(input.ndim != 3 for input in inputs):
             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
@@ -61,7 +59,7 @@ class OrientationPredictor(nn.Module):
         predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]
 
         class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
-        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]  # type: ignore
+        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]  # type: ignore
         confs = [round(float(p), 2) for prob in probs for p in prob]
 
         return [class_idxs, classes, confs]
```
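In practice this predictor is normally obtained through doctr's model zoo rather than constructed directly. A hedged usage sketch, assuming the `page_orientation_predictor` factory exposed by `doctr.models` (the factory is not shown in this diff):

```python
import numpy as np
from doctr.models import page_orientation_predictor  # factory assumed from doctr's zoo

predictor = page_orientation_predictor(pretrained=True)
pages = [np.zeros((1024, 768, 3), dtype=np.uint8)]  # each input must be a 3-dim (H, W, C) array

# Per forward() above, the result is a 3-item list: raw class indices, the
# mapped angles from model.cfg["classes"], and confidences rounded to 2 digits.
class_idxs, angles, confs = predictor(pages)
print(class_idxs, angles, confs)
```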
doctr/models/classification/predictor/tensorflow.py

```diff
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Optional, Union
 
 import numpy as np
 import tensorflow as tf
@@ -20,25 +19,24 @@ class OrientationPredictor(NestedObject):
     4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core classification architecture (backbone + classification head)
     """
 
-    _children_names: List[str] = ["pre_processor", "model"]
+    _children_names: list[str] = ["pre_processor", "model"]
 
     def __init__(
         self,
-        pre_processor: Optional[PreProcessor],
-        model: Optional[Model],
+        pre_processor: PreProcessor | None,
+        model: Model | None,
     ) -> None:
         self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
         self.model = model if isinstance(model, Model) else None
 
     def __call__(
         self,
-        inputs: List[Union[np.ndarray, tf.Tensor]],
-    ) -> List[Union[List[int], List[float]]]:
+        inputs: list[np.ndarray | tf.Tensor],
+    ) -> list[list[int] | list[float]]:
         # Dimension check
         if any(input.ndim != 3 for input in inputs):
             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
```
doctr/models/classification/resnet/__init__.py

```diff
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
```
doctr/models/classification/resnet/pytorch.py

```diff
@@ -1,11 +1,12 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any
 
 from torch import nn
 from torchvision.models.resnet import BasicBlock
@@ -21,7 +22,7 @@ from ...utils import conv_sequence_pt, load_pretrained_params
 __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide", "resnet_stage"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "resnet18": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -60,9 +61,9 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 }
 
 
-def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> List[nn.Module]:
+def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> list[nn.Module]:
     """Build a ResNet stage"""
-    _layers: List[nn.Module] = []
+    _layers: list[nn.Module] = []
 
     in_chan = in_channels
     s = stride
@@ -84,7 +85,6 @@ class ResNet(nn.Sequential):
     Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
-    ----
         num_blocks: number of resnet block in each stage
         output_channels: number of channels in each stage
         stage_conv: whether to add a conv_sequence after each stage
@@ -98,19 +98,19 @@
 
     def __init__(
         self,
-        num_blocks: List[int],
-        output_channels: List[int],
-        stage_stride: List[int],
-        stage_conv: List[bool],
-        stage_pooling: List[Optional[Tuple[int, int]]],
+        num_blocks: list[int],
+        output_channels: list[int],
+        stage_stride: list[int],
+        stage_conv: list[bool],
+        stage_pooling: list[tuple[int, int] | None],
         origin_stem: bool = True,
         stem_channels: int = 64,
-        attn_module: Optional[Callable[[int], nn.Module]] = None,
+        attn_module: Callable[[int], nn.Module] | None = None,
         include_top: bool = True,
         num_classes: int = 1000,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
-        _layers: List[nn.Module]
+        _layers: list[nn.Module]
         if origin_stem:
             _layers = [
                 *conv_sequence_pt(3, stem_channels, True, True, kernel_size=7, padding=3, stride=2),
@@ -156,12 +156,12 @@
 def _resnet(
     arch: str,
     pretrained: bool,
-    num_blocks: List[int],
-    output_channels: List[int],
-    stage_stride: List[int],
-    stage_conv: List[bool],
-    stage_pooling: List[Optional[Tuple[int, int]]],
-    ignore_keys: Optional[List[str]] = None,
+    num_blocks: list[int],
+    output_channels: list[int],
+    stage_stride: list[int],
+    stage_conv: list[bool],
+    stage_pooling: list[tuple[int, int] | None],
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> ResNet:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -188,7 +188,7 @@ def _tv_resnet(
     arch: str,
     pretrained: bool,
     arch_fn,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> TVResNet:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -224,12 +224,10 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> TVResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet18 model
     """
     return _tv_resnet(
@@ -253,12 +251,10 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet31 model
     """
     return _resnet(
@@ -287,12 +283,10 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> TVResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet34 model
     """
     return _tv_resnet(
@@ -315,12 +309,10 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet34_wide model
     """
     return _resnet(
@@ -349,12 +341,10 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> TVResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet50 model
     """
     return _tv_resnet(
```
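The retyped `resnet_stage` helper is public (it sits in `__all__` above) and returns a bare `list[nn.Module]`, so stages can be composed by hand. A small sketch under the assumption that the blocks behave like standard torchvision `BasicBlock`s with a downsampling first block (the stage body is not shown in this diff):

```python
import torch
from torch import nn
from doctr.models.classification.resnet.pytorch import resnet_stage  # module path per this diff

# Build one stage: 64 -> 128 channels, two blocks, stride 2 applied up front.
stage: list[nn.Module] = resnet_stage(64, 128, num_blocks=2, stride=2)
backbone = nn.Sequential(*stage)

out = backbone(torch.rand(1, 64, 32, 32))
print(out.shape)  # expected torch.Size([1, 128, 16, 16]) if the stride halves the feature map
```

The same hunk also shows why `attn_module` is typed `Callable[[int], nn.Module] | None`: any callable mapping a channel count to a module (a squeeze-and-excitation block, for instance) can be injected per stage without subclassing `ResNet`.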