python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
8
9
|
|
|
9
10
|
import tensorflow as tf
|
|
10
11
|
from tensorflow.keras import layers
|
|
@@ -13,46 +14,46 @@ from tensorflow.keras.models import Sequential
|
|
|
13
14
|
|
|
14
15
|
from doctr.datasets import VOCABS
|
|
15
16
|
|
|
16
|
-
from ...utils import conv_sequence, load_pretrained_params
|
|
17
|
+
from ...utils import _build_model, conv_sequence, load_pretrained_params
|
|
17
18
|
|
|
18
19
|
__all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]
|
|
19
20
|
|
|
20
21
|
|
|
21
|
-
default_cfgs:
|
|
22
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
22
23
|
"resnet18": {
|
|
23
24
|
"mean": (0.694, 0.695, 0.693),
|
|
24
25
|
"std": (0.299, 0.296, 0.301),
|
|
25
26
|
"input_shape": (32, 32, 3),
|
|
26
27
|
"classes": list(VOCABS["french"]),
|
|
27
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
28
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
|
|
28
29
|
},
|
|
29
30
|
"resnet31": {
|
|
30
31
|
"mean": (0.694, 0.695, 0.693),
|
|
31
32
|
"std": (0.299, 0.296, 0.301),
|
|
32
33
|
"input_shape": (32, 32, 3),
|
|
33
34
|
"classes": list(VOCABS["french"]),
|
|
34
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
35
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
|
|
35
36
|
},
|
|
36
37
|
"resnet34": {
|
|
37
38
|
"mean": (0.694, 0.695, 0.693),
|
|
38
39
|
"std": (0.299, 0.296, 0.301),
|
|
39
40
|
"input_shape": (32, 32, 3),
|
|
40
41
|
"classes": list(VOCABS["french"]),
|
|
41
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
42
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
|
|
42
43
|
},
|
|
43
44
|
"resnet50": {
|
|
44
45
|
"mean": (0.694, 0.695, 0.693),
|
|
45
46
|
"std": (0.299, 0.296, 0.301),
|
|
46
47
|
"input_shape": (32, 32, 3),
|
|
47
48
|
"classes": list(VOCABS["french"]),
|
|
48
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
49
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
|
|
49
50
|
},
|
|
50
51
|
"resnet34_wide": {
|
|
51
52
|
"mean": (0.694, 0.695, 0.693),
|
|
52
53
|
"std": (0.299, 0.296, 0.301),
|
|
53
54
|
"input_shape": (32, 32, 3),
|
|
54
55
|
"classes": list(VOCABS["french"]),
|
|
55
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
56
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
|
|
56
57
|
},
|
|
57
58
|
}
|
|
58
59
|
|
|
@@ -61,7 +62,6 @@ class ResnetBlock(layers.Layer):
|
|
|
61
62
|
"""Implements a resnet31 block with shortcut
|
|
62
63
|
|
|
63
64
|
Args:
|
|
64
|
-
----
|
|
65
65
|
conv_shortcut: Use of shortcut
|
|
66
66
|
output_channels: number of channels to use in Conv2D
|
|
67
67
|
kernel_size: size of square kernels
|
|
@@ -92,7 +92,7 @@ class ResnetBlock(layers.Layer):
|
|
|
92
92
|
output_channels: int,
|
|
93
93
|
kernel_size: int,
|
|
94
94
|
strides: int = 1,
|
|
95
|
-
) ->
|
|
95
|
+
) -> list[layers.Layer]:
|
|
96
96
|
return [
|
|
97
97
|
*conv_sequence(output_channels, "relu", bn=True, strides=strides, kernel_size=kernel_size),
|
|
98
98
|
*conv_sequence(output_channels, None, bn=True, kernel_size=kernel_size),
|
|
@@ -108,8 +108,8 @@ class ResnetBlock(layers.Layer):
|
|
|
108
108
|
|
|
109
109
|
def resnet_stage(
|
|
110
110
|
num_blocks: int, out_channels: int, shortcut: bool = False, downsample: bool = False
|
|
111
|
-
) ->
|
|
112
|
-
_layers:
|
|
111
|
+
) -> list[layers.Layer]:
|
|
112
|
+
_layers: list[layers.Layer] = [ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1)]
|
|
113
113
|
|
|
114
114
|
for _ in range(1, num_blocks):
|
|
115
115
|
_layers.append(ResnetBlock(out_channels, conv_shortcut=False))
|
|
@@ -121,7 +121,6 @@ class ResNet(Sequential):
|
|
|
121
121
|
"""Implements a ResNet architecture
|
|
122
122
|
|
|
123
123
|
Args:
|
|
124
|
-
----
|
|
125
124
|
num_blocks: number of resnet block in each stage
|
|
126
125
|
output_channels: number of channels in each stage
|
|
127
126
|
stage_downsample: whether the first residual block of a stage should downsample
|
|
@@ -137,18 +136,18 @@ class ResNet(Sequential):
|
|
|
137
136
|
|
|
138
137
|
def __init__(
|
|
139
138
|
self,
|
|
140
|
-
num_blocks:
|
|
141
|
-
output_channels:
|
|
142
|
-
stage_downsample:
|
|
143
|
-
stage_conv:
|
|
144
|
-
stage_pooling:
|
|
139
|
+
num_blocks: list[int],
|
|
140
|
+
output_channels: list[int],
|
|
141
|
+
stage_downsample: list[bool],
|
|
142
|
+
stage_conv: list[bool],
|
|
143
|
+
stage_pooling: list[tuple[int, int] | None],
|
|
145
144
|
origin_stem: bool = True,
|
|
146
145
|
stem_channels: int = 64,
|
|
147
|
-
attn_module:
|
|
146
|
+
attn_module: Callable[[int], layers.Layer] | None = None,
|
|
148
147
|
include_top: bool = True,
|
|
149
148
|
num_classes: int = 1000,
|
|
150
|
-
cfg:
|
|
151
|
-
input_shape:
|
|
149
|
+
cfg: dict[str, Any] | None = None,
|
|
150
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
152
151
|
) -> None:
|
|
153
152
|
inplanes = stem_channels
|
|
154
153
|
if origin_stem:
|
|
@@ -188,11 +187,11 @@ class ResNet(Sequential):
|
|
|
188
187
|
def _resnet(
|
|
189
188
|
arch: str,
|
|
190
189
|
pretrained: bool,
|
|
191
|
-
num_blocks:
|
|
192
|
-
output_channels:
|
|
193
|
-
stage_downsample:
|
|
194
|
-
stage_conv:
|
|
195
|
-
stage_pooling:
|
|
190
|
+
num_blocks: list[int],
|
|
191
|
+
output_channels: list[int],
|
|
192
|
+
stage_downsample: list[bool],
|
|
193
|
+
stage_conv: list[bool],
|
|
194
|
+
stage_pooling: list[tuple[int, int] | None],
|
|
196
195
|
origin_stem: bool = True,
|
|
197
196
|
**kwargs: Any,
|
|
198
197
|
) -> ResNet:
|
|
@@ -210,9 +209,15 @@ def _resnet(
|
|
|
210
209
|
model = ResNet(
|
|
211
210
|
num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
|
|
212
211
|
)
|
|
212
|
+
_build_model(model)
|
|
213
|
+
|
|
213
214
|
# Load pretrained parameters
|
|
214
215
|
if pretrained:
|
|
215
|
-
|
|
216
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
217
|
+
# skip the mismatching layers for fine tuning
|
|
218
|
+
load_pretrained_params(
|
|
219
|
+
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
220
|
+
)
|
|
216
221
|
|
|
217
222
|
return model
|
|
218
223
|
|
|
@@ -228,12 +233,10 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
228
233
|
>>> out = model(input_tensor)
|
|
229
234
|
|
|
230
235
|
Args:
|
|
231
|
-
----
|
|
232
236
|
pretrained: boolean, True if model is pretrained
|
|
233
237
|
**kwargs: keyword arguments of the ResNet architecture
|
|
234
238
|
|
|
235
239
|
Returns:
|
|
236
|
-
-------
|
|
237
240
|
A classification model
|
|
238
241
|
"""
|
|
239
242
|
return _resnet(
|
|
@@ -261,12 +264,10 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
261
264
|
>>> out = model(input_tensor)
|
|
262
265
|
|
|
263
266
|
Args:
|
|
264
|
-
----
|
|
265
267
|
pretrained: boolean, True if model is pretrained
|
|
266
268
|
**kwargs: keyword arguments of the ResNet architecture
|
|
267
269
|
|
|
268
270
|
Returns:
|
|
269
|
-
-------
|
|
270
271
|
A classification model
|
|
271
272
|
"""
|
|
272
273
|
return _resnet(
|
|
@@ -294,12 +295,10 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
294
295
|
>>> out = model(input_tensor)
|
|
295
296
|
|
|
296
297
|
Args:
|
|
297
|
-
----
|
|
298
298
|
pretrained: boolean, True if model is pretrained
|
|
299
299
|
**kwargs: keyword arguments of the ResNet architecture
|
|
300
300
|
|
|
301
301
|
Returns:
|
|
302
|
-
-------
|
|
303
302
|
A classification model
|
|
304
303
|
"""
|
|
305
304
|
return _resnet(
|
|
@@ -326,12 +325,10 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
326
325
|
>>> out = model(input_tensor)
|
|
327
326
|
|
|
328
327
|
Args:
|
|
329
|
-
----
|
|
330
328
|
pretrained: boolean, True if model is pretrained
|
|
331
329
|
**kwargs: keyword arguments of the ResNet architecture
|
|
332
330
|
|
|
333
331
|
Returns:
|
|
334
|
-
-------
|
|
335
332
|
A classification model
|
|
336
333
|
"""
|
|
337
334
|
kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"]))
|
|
@@ -354,10 +351,17 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
354
351
|
)
|
|
355
352
|
|
|
356
353
|
model.cfg = _cfg
|
|
354
|
+
_build_model(model)
|
|
357
355
|
|
|
358
356
|
# Load pretrained parameters
|
|
359
357
|
if pretrained:
|
|
360
|
-
|
|
358
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
359
|
+
# skip the mismatching layers for fine tuning
|
|
360
|
+
load_pretrained_params(
|
|
361
|
+
model,
|
|
362
|
+
default_cfgs["resnet50"]["url"],
|
|
363
|
+
skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
|
|
364
|
+
)
|
|
361
365
|
|
|
362
366
|
return model
|
|
363
367
|
|
|
@@ -373,12 +377,10 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
373
377
|
>>> out = model(input_tensor)
|
|
374
378
|
|
|
375
379
|
Args:
|
|
376
|
-
----
|
|
377
380
|
pretrained: boolean, True if model is pretrained
|
|
378
381
|
**kwargs: keyword arguments of the ResNet architecture
|
|
379
382
|
|
|
380
383
|
Returns:
|
|
381
|
-
-------
|
|
382
384
|
A classification model
|
|
383
385
|
"""
|
|
384
386
|
return _resnet(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
from copy import deepcopy
|
|
8
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
9
9
|
|
|
10
10
|
from torch import nn
|
|
11
11
|
|
|
@@ -16,7 +16,7 @@ from ...utils import conv_sequence_pt, load_pretrained_params
|
|
|
16
16
|
|
|
17
17
|
__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
|
|
18
18
|
|
|
19
|
-
default_cfgs:
|
|
19
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
20
20
|
"textnet_tiny": {
|
|
21
21
|
"mean": (0.694, 0.695, 0.693),
|
|
22
22
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -47,22 +47,21 @@ class TextNet(nn.Sequential):
|
|
|
47
47
|
Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
|
|
48
48
|
|
|
49
49
|
Args:
|
|
50
|
-
|
|
51
|
-
stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
|
|
50
|
+
stages (list[dict[str, list[int]]]): list of dictionaries containing the parameters of each stage.
|
|
52
51
|
include_top (bool, optional): Whether to include the classifier head. Defaults to True.
|
|
53
52
|
num_classes (int, optional): Number of output classes. Defaults to 1000.
|
|
54
|
-
cfg (
|
|
53
|
+
cfg (dict[str, Any], optional): Additional configuration. Defaults to None.
|
|
55
54
|
"""
|
|
56
55
|
|
|
57
56
|
def __init__(
|
|
58
57
|
self,
|
|
59
|
-
stages:
|
|
60
|
-
input_shape:
|
|
58
|
+
stages: list[dict[str, list[int]]],
|
|
59
|
+
input_shape: tuple[int, int, int] = (3, 32, 32),
|
|
61
60
|
num_classes: int = 1000,
|
|
62
61
|
include_top: bool = True,
|
|
63
|
-
cfg:
|
|
62
|
+
cfg: dict[str, Any] | None = None,
|
|
64
63
|
) -> None:
|
|
65
|
-
_layers:
|
|
64
|
+
_layers: list[nn.Module] = [
|
|
66
65
|
*conv_sequence_pt(
|
|
67
66
|
in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
|
|
68
67
|
),
|
|
@@ -98,7 +97,7 @@ class TextNet(nn.Sequential):
|
|
|
98
97
|
def _textnet(
|
|
99
98
|
arch: str,
|
|
100
99
|
pretrained: bool,
|
|
101
|
-
ignore_keys:
|
|
100
|
+
ignore_keys: list[str] | None = None,
|
|
102
101
|
**kwargs: Any,
|
|
103
102
|
) -> TextNet:
|
|
104
103
|
kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
|
|
@@ -135,12 +134,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
|
|
|
135
134
|
>>> out = model(input_tensor)
|
|
136
135
|
|
|
137
136
|
Args:
|
|
138
|
-
----
|
|
139
137
|
pretrained: boolean, True if model is pretrained
|
|
140
138
|
**kwargs: keyword arguments of the TextNet architecture
|
|
141
139
|
|
|
142
140
|
Returns:
|
|
143
|
-
-------
|
|
144
141
|
A textnet tiny model
|
|
145
142
|
"""
|
|
146
143
|
return _textnet(
|
|
@@ -184,12 +181,10 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
|
|
|
184
181
|
>>> out = model(input_tensor)
|
|
185
182
|
|
|
186
183
|
Args:
|
|
187
|
-
----
|
|
188
184
|
pretrained: boolean, True if model is pretrained
|
|
189
185
|
**kwargs: keyword arguments of the TextNet architecture
|
|
190
186
|
|
|
191
187
|
Returns:
|
|
192
|
-
-------
|
|
193
188
|
A TextNet small model
|
|
194
189
|
"""
|
|
195
190
|
return _textnet(
|
|
@@ -233,12 +228,10 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
|
|
|
233
228
|
>>> out = model(input_tensor)
|
|
234
229
|
|
|
235
230
|
Args:
|
|
236
|
-
----
|
|
237
231
|
pretrained: boolean, True if model is pretrained
|
|
238
232
|
**kwargs: keyword arguments of the TextNet architecture
|
|
239
233
|
|
|
240
234
|
Returns:
|
|
241
|
-
-------
|
|
242
235
|
A TextNet base model
|
|
243
236
|
"""
|
|
244
237
|
return _textnet(
|
|
@@ -1,42 +1,42 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
from copy import deepcopy
|
|
8
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
9
9
|
|
|
10
10
|
from tensorflow.keras import Sequential, layers
|
|
11
11
|
|
|
12
12
|
from doctr.datasets import VOCABS
|
|
13
13
|
|
|
14
14
|
from ...modules.layers.tensorflow import FASTConvLayer
|
|
15
|
-
from ...utils import conv_sequence, load_pretrained_params
|
|
15
|
+
from ...utils import _build_model, conv_sequence, load_pretrained_params
|
|
16
16
|
|
|
17
17
|
__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
|
|
18
18
|
|
|
19
|
-
default_cfgs:
|
|
19
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
20
20
|
"textnet_tiny": {
|
|
21
21
|
"mean": (0.694, 0.695, 0.693),
|
|
22
22
|
"std": (0.299, 0.296, 0.301),
|
|
23
23
|
"input_shape": (32, 32, 3),
|
|
24
24
|
"classes": list(VOCABS["french"]),
|
|
25
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
25
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
|
|
26
26
|
},
|
|
27
27
|
"textnet_small": {
|
|
28
28
|
"mean": (0.694, 0.695, 0.693),
|
|
29
29
|
"std": (0.299, 0.296, 0.301),
|
|
30
30
|
"input_shape": (32, 32, 3),
|
|
31
31
|
"classes": list(VOCABS["french"]),
|
|
32
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
32
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
|
|
33
33
|
},
|
|
34
34
|
"textnet_base": {
|
|
35
35
|
"mean": (0.694, 0.695, 0.693),
|
|
36
36
|
"std": (0.299, 0.296, 0.301),
|
|
37
37
|
"input_shape": (32, 32, 3),
|
|
38
38
|
"classes": list(VOCABS["french"]),
|
|
39
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
39
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
|
|
40
40
|
},
|
|
41
41
|
}
|
|
42
42
|
|
|
@@ -47,20 +47,19 @@ class TextNet(Sequential):
|
|
|
47
47
|
Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
|
|
48
48
|
|
|
49
49
|
Args:
|
|
50
|
-
|
|
51
|
-
stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
|
|
50
|
+
stages (list[dict[str, list[int]]]): list of dictionaries containing the parameters of each stage.
|
|
52
51
|
include_top (bool, optional): Whether to include the classifier head. Defaults to True.
|
|
53
52
|
num_classes (int, optional): Number of output classes. Defaults to 1000.
|
|
54
|
-
cfg (
|
|
53
|
+
cfg (dict[str, Any], optional): Additional configuration. Defaults to None.
|
|
55
54
|
"""
|
|
56
55
|
|
|
57
56
|
def __init__(
|
|
58
57
|
self,
|
|
59
|
-
stages:
|
|
60
|
-
input_shape:
|
|
58
|
+
stages: list[dict[str, list[int]]],
|
|
59
|
+
input_shape: tuple[int, int, int] = (32, 32, 3),
|
|
61
60
|
num_classes: int = 1000,
|
|
62
61
|
include_top: bool = True,
|
|
63
|
-
cfg:
|
|
62
|
+
cfg: dict[str, Any] | None = None,
|
|
64
63
|
) -> None:
|
|
65
64
|
_layers = [
|
|
66
65
|
*conv_sequence(
|
|
@@ -111,9 +110,15 @@ def _textnet(
|
|
|
111
110
|
|
|
112
111
|
# Build the model
|
|
113
112
|
model = TextNet(cfg=_cfg, **kwargs)
|
|
113
|
+
_build_model(model)
|
|
114
|
+
|
|
114
115
|
# Load pretrained parameters
|
|
115
116
|
if pretrained:
|
|
116
|
-
|
|
117
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
118
|
+
# skip the mismatching layers for fine tuning
|
|
119
|
+
load_pretrained_params(
|
|
120
|
+
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
121
|
+
)
|
|
117
122
|
|
|
118
123
|
return model
|
|
119
124
|
|
|
@@ -130,12 +135,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
|
|
|
130
135
|
>>> out = model(input_tensor)
|
|
131
136
|
|
|
132
137
|
Args:
|
|
133
|
-
----
|
|
134
138
|
pretrained: boolean, True if model is pretrained
|
|
135
139
|
**kwargs: keyword arguments of the TextNet architecture
|
|
136
140
|
|
|
137
141
|
Returns:
|
|
138
|
-
-------
|
|
139
142
|
A textnet tiny model
|
|
140
143
|
"""
|
|
141
144
|
return _textnet(
|
|
@@ -178,12 +181,10 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
|
|
|
178
181
|
>>> out = model(input_tensor)
|
|
179
182
|
|
|
180
183
|
Args:
|
|
181
|
-
----
|
|
182
184
|
pretrained: boolean, True if model is pretrained
|
|
183
185
|
**kwargs: keyword arguments of the TextNet architecture
|
|
184
186
|
|
|
185
187
|
Returns:
|
|
186
|
-
-------
|
|
187
188
|
A TextNet small model
|
|
188
189
|
"""
|
|
189
190
|
return _textnet(
|
|
@@ -226,12 +227,10 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
|
|
|
226
227
|
>>> out = model(input_tensor)
|
|
227
228
|
|
|
228
229
|
Args:
|
|
229
|
-
----
|
|
230
230
|
pretrained: boolean, True if model is pretrained
|
|
231
231
|
**kwargs: keyword arguments of the TextNet architecture
|
|
232
232
|
|
|
233
233
|
Returns:
|
|
234
|
-
-------
|
|
235
234
|
A TextNet base model
|
|
236
235
|
"""
|
|
237
236
|
return _textnet(
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
from torch import nn
|
|
10
10
|
from torchvision.models import vgg as tv_vgg
|
|
@@ -16,7 +16,7 @@ from ...utils import load_pretrained_params
|
|
|
16
16
|
__all__ = ["vgg16_bn_r"]
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
default_cfgs:
|
|
19
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
20
20
|
"vgg16_bn_r": {
|
|
21
21
|
"mean": (0.694, 0.695, 0.693),
|
|
22
22
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -32,7 +32,7 @@ def _vgg(
|
|
|
32
32
|
pretrained: bool,
|
|
33
33
|
tv_arch: str,
|
|
34
34
|
num_rect_pools: int = 3,
|
|
35
|
-
ignore_keys:
|
|
35
|
+
ignore_keys: list[str] | None = None,
|
|
36
36
|
**kwargs: Any,
|
|
37
37
|
) -> tv_vgg.VGG:
|
|
38
38
|
kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
|
|
@@ -45,7 +45,7 @@ def _vgg(
|
|
|
45
45
|
|
|
46
46
|
# Build the model
|
|
47
47
|
model = tv_vgg.__dict__[tv_arch](**kwargs, weights=None)
|
|
48
|
-
#
|
|
48
|
+
# list the MaxPool2d
|
|
49
49
|
pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)]
|
|
50
50
|
# Replace their kernel with rectangular ones
|
|
51
51
|
for idx in pool_idcs[-num_rect_pools:]:
|
|
@@ -77,12 +77,10 @@ def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> tv_vgg.VGG:
|
|
|
77
77
|
>>> out = model(input_tensor)
|
|
78
78
|
|
|
79
79
|
Args:
|
|
80
|
-
----
|
|
81
80
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
|
82
81
|
**kwargs: keyword arguments of the VGG architecture
|
|
83
82
|
|
|
84
83
|
Returns:
|
|
85
|
-
-------
|
|
86
84
|
VGG feature extractor
|
|
87
85
|
"""
|
|
88
86
|
return _vgg(
|
|
@@ -1,28 +1,28 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
from tensorflow.keras import layers
|
|
10
10
|
from tensorflow.keras.models import Sequential
|
|
11
11
|
|
|
12
12
|
from doctr.datasets import VOCABS
|
|
13
13
|
|
|
14
|
-
from ...utils import conv_sequence, load_pretrained_params
|
|
14
|
+
from ...utils import _build_model, conv_sequence, load_pretrained_params
|
|
15
15
|
|
|
16
16
|
__all__ = ["VGG", "vgg16_bn_r"]
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
default_cfgs:
|
|
19
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
20
20
|
"vgg16_bn_r": {
|
|
21
21
|
"mean": (0.5, 0.5, 0.5),
|
|
22
22
|
"std": (1.0, 1.0, 1.0),
|
|
23
23
|
"input_shape": (32, 32, 3),
|
|
24
24
|
"classes": list(VOCABS["french"]),
|
|
25
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
25
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
|
|
26
26
|
},
|
|
27
27
|
}
|
|
28
28
|
|
|
@@ -32,7 +32,6 @@ class VGG(Sequential):
|
|
|
32
32
|
<https://arxiv.org/pdf/1409.1556.pdf>`_.
|
|
33
33
|
|
|
34
34
|
Args:
|
|
35
|
-
----
|
|
36
35
|
num_blocks: number of convolutional block in each stage
|
|
37
36
|
planes: number of output channels in each stage
|
|
38
37
|
rect_pools: whether pooling square kernels should be replace with rectangular ones
|
|
@@ -43,13 +42,13 @@ class VGG(Sequential):
|
|
|
43
42
|
|
|
44
43
|
def __init__(
|
|
45
44
|
self,
|
|
46
|
-
num_blocks:
|
|
47
|
-
planes:
|
|
48
|
-
rect_pools:
|
|
45
|
+
num_blocks: list[int],
|
|
46
|
+
planes: list[int],
|
|
47
|
+
rect_pools: list[bool],
|
|
49
48
|
include_top: bool = False,
|
|
50
49
|
num_classes: int = 1000,
|
|
51
|
-
input_shape:
|
|
52
|
-
cfg:
|
|
50
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
51
|
+
cfg: dict[str, Any] | None = None,
|
|
53
52
|
) -> None:
|
|
54
53
|
_layers = []
|
|
55
54
|
# Specify input_shape only for the first layer
|
|
@@ -67,7 +66,7 @@ class VGG(Sequential):
|
|
|
67
66
|
|
|
68
67
|
|
|
69
68
|
def _vgg(
|
|
70
|
-
arch: str, pretrained: bool, num_blocks:
|
|
69
|
+
arch: str, pretrained: bool, num_blocks: list[int], planes: list[int], rect_pools: list[bool], **kwargs: Any
|
|
71
70
|
) -> VGG:
|
|
72
71
|
kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
|
|
73
72
|
kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
|
|
@@ -81,9 +80,15 @@ def _vgg(
|
|
|
81
80
|
|
|
82
81
|
# Build the model
|
|
83
82
|
model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
|
|
83
|
+
_build_model(model)
|
|
84
|
+
|
|
84
85
|
# Load pretrained parameters
|
|
85
86
|
if pretrained:
|
|
86
|
-
|
|
87
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
88
|
+
# skip the mismatching layers for fine tuning
|
|
89
|
+
load_pretrained_params(
|
|
90
|
+
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
91
|
+
)
|
|
87
92
|
|
|
88
93
|
return model
|
|
89
94
|
|
|
@@ -100,12 +105,10 @@ def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG:
|
|
|
100
105
|
>>> out = model(input_tensor)
|
|
101
106
|
|
|
102
107
|
Args:
|
|
103
|
-
----
|
|
104
108
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
|
105
109
|
**kwargs: keyword arguments of the VGG architecture
|
|
106
110
|
|
|
107
111
|
Returns:
|
|
108
|
-
-------
|
|
109
112
|
VGG feature extractor
|
|
110
113
|
"""
|
|
111
114
|
return _vgg(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|