python-doctr 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/cord.py +10 -1
- doctr/datasets/funsd.py +11 -1
- doctr/datasets/ic03.py +11 -1
- doctr/datasets/ic13.py +10 -1
- doctr/datasets/iiit5k.py +26 -16
- doctr/datasets/imgur5k.py +10 -1
- doctr/datasets/sroie.py +11 -1
- doctr/datasets/svhn.py +11 -1
- doctr/datasets/svt.py +11 -1
- doctr/datasets/synthtext.py +11 -1
- doctr/datasets/utils.py +7 -2
- doctr/datasets/vocabs.py +6 -2
- doctr/datasets/wildreceipt.py +12 -1
- doctr/file_utils.py +19 -0
- doctr/io/elements.py +12 -4
- doctr/models/builder.py +2 -2
- doctr/models/classification/magc_resnet/tensorflow.py +13 -6
- doctr/models/classification/mobilenet/pytorch.py +2 -0
- doctr/models/classification/mobilenet/tensorflow.py +14 -8
- doctr/models/classification/predictor/pytorch.py +11 -7
- doctr/models/classification/predictor/tensorflow.py +10 -6
- doctr/models/classification/resnet/tensorflow.py +21 -8
- doctr/models/classification/textnet/tensorflow.py +11 -5
- doctr/models/classification/vgg/tensorflow.py +9 -3
- doctr/models/classification/vit/tensorflow.py +10 -4
- doctr/models/classification/zoo.py +22 -10
- doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
- doctr/models/detection/fast/tensorflow.py +14 -11
- doctr/models/detection/linknet/tensorflow.py +23 -11
- doctr/models/detection/predictor/tensorflow.py +2 -2
- doctr/models/factory/hub.py +5 -6
- doctr/models/kie_predictor/base.py +4 -0
- doctr/models/kie_predictor/pytorch.py +4 -0
- doctr/models/kie_predictor/tensorflow.py +8 -1
- doctr/models/modules/transformer/tensorflow.py +0 -2
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/modules/vision_transformer/tensorflow.py +1 -1
- doctr/models/predictor/base.py +24 -12
- doctr/models/predictor/pytorch.py +4 -0
- doctr/models/predictor/tensorflow.py +8 -1
- doctr/models/preprocessor/tensorflow.py +1 -1
- doctr/models/recognition/crnn/tensorflow.py +8 -6
- doctr/models/recognition/master/tensorflow.py +9 -4
- doctr/models/recognition/parseq/tensorflow.py +10 -8
- doctr/models/recognition/sar/tensorflow.py +7 -3
- doctr/models/recognition/vitstr/tensorflow.py +9 -4
- doctr/models/utils/pytorch.py +1 -1
- doctr/models/utils/tensorflow.py +15 -15
- doctr/transforms/functional/pytorch.py +1 -1
- doctr/transforms/modules/pytorch.py +7 -6
- doctr/transforms/modules/tensorflow.py +15 -12
- doctr/utils/geometry.py +106 -19
- doctr/utils/metrics.py +1 -1
- doctr/utils/reconstitution.py +151 -65
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/METADATA +11 -11
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/RECORD +61 -61
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
|
@@ -9,12 +9,12 @@ from functools import partial
|
|
|
9
9
|
from typing import Any, Dict, List, Optional, Tuple
|
|
10
10
|
|
|
11
11
|
import tensorflow as tf
|
|
12
|
-
from tensorflow.keras import layers
|
|
12
|
+
from tensorflow.keras import activations, layers
|
|
13
13
|
from tensorflow.keras.models import Sequential
|
|
14
14
|
|
|
15
15
|
from doctr.datasets import VOCABS
|
|
16
16
|
|
|
17
|
-
from ...utils import load_pretrained_params
|
|
17
|
+
from ...utils import _build_model, load_pretrained_params
|
|
18
18
|
from ..resnet.tensorflow import ResNet
|
|
19
19
|
|
|
20
20
|
__all__ = ["magc_resnet31"]
|
|
@@ -26,7 +26,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
26
26
|
"std": (0.299, 0.296, 0.301),
|
|
27
27
|
"input_shape": (32, 32, 3),
|
|
28
28
|
"classes": list(VOCABS["french"]),
|
|
29
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
29
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
|
|
30
30
|
},
|
|
31
31
|
}
|
|
32
32
|
|
|
@@ -57,6 +57,7 @@ class MAGC(layers.Layer):
|
|
|
57
57
|
self.headers = headers # h
|
|
58
58
|
self.inplanes = inplanes # C
|
|
59
59
|
self.attn_scale = attn_scale
|
|
60
|
+
self.ratio = ratio
|
|
60
61
|
self.planes = int(inplanes * ratio)
|
|
61
62
|
|
|
62
63
|
self.single_header_inplanes = int(inplanes / headers) # C / h
|
|
@@ -97,7 +98,7 @@ class MAGC(layers.Layer):
|
|
|
97
98
|
if self.attn_scale and self.headers > 1:
|
|
98
99
|
context_mask = context_mask / math.sqrt(self.single_header_inplanes)
|
|
99
100
|
# B*h, 1, H*W, 1
|
|
100
|
-
context_mask =
|
|
101
|
+
context_mask = activations.softmax(context_mask, axis=2)
|
|
101
102
|
|
|
102
103
|
# Compute context
|
|
103
104
|
# B*h, 1, C/h, 1
|
|
@@ -114,7 +115,7 @@ class MAGC(layers.Layer):
|
|
|
114
115
|
# Context modeling: B, H, W, C -> B, 1, 1, C
|
|
115
116
|
context = self.context_modeling(inputs)
|
|
116
117
|
# Transform: B, 1, 1, C -> B, 1, 1, C
|
|
117
|
-
transformed = self.transform(context)
|
|
118
|
+
transformed = self.transform(context, **kwargs)
|
|
118
119
|
return inputs + transformed
|
|
119
120
|
|
|
120
121
|
|
|
@@ -151,9 +152,15 @@ def _magc_resnet(
|
|
|
151
152
|
cfg=_cfg,
|
|
152
153
|
**kwargs,
|
|
153
154
|
)
|
|
155
|
+
_build_model(model)
|
|
156
|
+
|
|
154
157
|
# Load pretrained parameters
|
|
155
158
|
if pretrained:
|
|
156
|
-
|
|
159
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
160
|
+
# skip the mismatching layers for fine tuning
|
|
161
|
+
load_pretrained_params(
|
|
162
|
+
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
163
|
+
)
|
|
157
164
|
|
|
158
165
|
return model
|
|
159
166
|
|
|
@@ -9,12 +9,14 @@ from copy import deepcopy
|
|
|
9
9
|
from typing import Any, Dict, List, Optional
|
|
10
10
|
|
|
11
11
|
from torchvision.models import mobilenetv3
|
|
12
|
+
from torchvision.models.mobilenetv3 import MobileNetV3
|
|
12
13
|
|
|
13
14
|
from doctr.datasets import VOCABS
|
|
14
15
|
|
|
15
16
|
from ...utils import load_pretrained_params
|
|
16
17
|
|
|
17
18
|
__all__ = [
|
|
19
|
+
"MobileNetV3",
|
|
18
20
|
"mobilenet_v3_small",
|
|
19
21
|
"mobilenet_v3_small_r",
|
|
20
22
|
"mobilenet_v3_large",
|
|
@@ -13,7 +13,7 @@ from tensorflow.keras import layers
|
|
|
13
13
|
from tensorflow.keras.models import Sequential
|
|
14
14
|
|
|
15
15
|
from ....datasets import VOCABS
|
|
16
|
-
from ...utils import conv_sequence, load_pretrained_params
|
|
16
|
+
from ...utils import _build_model, conv_sequence, load_pretrained_params
|
|
17
17
|
|
|
18
18
|
__all__ = [
|
|
19
19
|
"MobileNetV3",
|
|
@@ -32,42 +32,42 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
32
32
|
"std": (0.299, 0.296, 0.301),
|
|
33
33
|
"input_shape": (32, 32, 3),
|
|
34
34
|
"classes": list(VOCABS["french"]),
|
|
35
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
35
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
|
|
36
36
|
},
|
|
37
37
|
"mobilenet_v3_large_r": {
|
|
38
38
|
"mean": (0.694, 0.695, 0.693),
|
|
39
39
|
"std": (0.299, 0.296, 0.301),
|
|
40
40
|
"input_shape": (32, 32, 3),
|
|
41
41
|
"classes": list(VOCABS["french"]),
|
|
42
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
42
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
|
|
43
43
|
},
|
|
44
44
|
"mobilenet_v3_small": {
|
|
45
45
|
"mean": (0.694, 0.695, 0.693),
|
|
46
46
|
"std": (0.299, 0.296, 0.301),
|
|
47
47
|
"input_shape": (32, 32, 3),
|
|
48
48
|
"classes": list(VOCABS["french"]),
|
|
49
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
49
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
|
|
50
50
|
},
|
|
51
51
|
"mobilenet_v3_small_r": {
|
|
52
52
|
"mean": (0.694, 0.695, 0.693),
|
|
53
53
|
"std": (0.299, 0.296, 0.301),
|
|
54
54
|
"input_shape": (32, 32, 3),
|
|
55
55
|
"classes": list(VOCABS["french"]),
|
|
56
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
56
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
|
|
57
57
|
},
|
|
58
58
|
"mobilenet_v3_small_crop_orientation": {
|
|
59
59
|
"mean": (0.694, 0.695, 0.693),
|
|
60
60
|
"std": (0.299, 0.296, 0.301),
|
|
61
61
|
"input_shape": (128, 128, 3),
|
|
62
62
|
"classes": [0, -90, 180, 90],
|
|
63
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
63
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
|
|
64
64
|
},
|
|
65
65
|
"mobilenet_v3_small_page_orientation": {
|
|
66
66
|
"mean": (0.694, 0.695, 0.693),
|
|
67
67
|
"std": (0.299, 0.296, 0.301),
|
|
68
68
|
"input_shape": (512, 512, 3),
|
|
69
69
|
"classes": [0, -90, 180, 90],
|
|
70
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
70
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
|
|
71
71
|
},
|
|
72
72
|
}
|
|
73
73
|
|
|
@@ -295,9 +295,15 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
|
|
|
295
295
|
cfg=_cfg,
|
|
296
296
|
**kwargs,
|
|
297
297
|
)
|
|
298
|
+
_build_model(model)
|
|
299
|
+
|
|
298
300
|
# Load pretrained parameters
|
|
299
301
|
if pretrained:
|
|
300
|
-
|
|
302
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
303
|
+
# skip the mismatching layers for fine tuning
|
|
304
|
+
load_pretrained_params(
|
|
305
|
+
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
306
|
+
)
|
|
301
307
|
|
|
302
308
|
return model
|
|
303
309
|
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List, Union
|
|
6
|
+
from typing import List, Optional, Union
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -27,12 +27,12 @@ class OrientationPredictor(nn.Module):
|
|
|
27
27
|
|
|
28
28
|
def __init__(
|
|
29
29
|
self,
|
|
30
|
-
pre_processor: PreProcessor,
|
|
31
|
-
model: nn.Module,
|
|
30
|
+
pre_processor: Optional[PreProcessor],
|
|
31
|
+
model: Optional[nn.Module],
|
|
32
32
|
) -> None:
|
|
33
33
|
super().__init__()
|
|
34
|
-
self.pre_processor = pre_processor
|
|
35
|
-
self.model = model.eval()
|
|
34
|
+
self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
|
|
35
|
+
self.model = model.eval() if isinstance(model, nn.Module) else None
|
|
36
36
|
|
|
37
37
|
@torch.inference_mode()
|
|
38
38
|
def forward(
|
|
@@ -43,12 +43,16 @@ class OrientationPredictor(nn.Module):
|
|
|
43
43
|
if any(input.ndim != 3 for input in inputs):
|
|
44
44
|
raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
|
|
45
45
|
|
|
46
|
+
if self.model is None or self.pre_processor is None:
|
|
47
|
+
# predictor is disabled
|
|
48
|
+
return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
|
|
49
|
+
|
|
46
50
|
processed_batches = self.pre_processor(inputs)
|
|
47
51
|
_params = next(self.model.parameters())
|
|
48
52
|
self.model, processed_batches = set_device_and_dtype(
|
|
49
53
|
self.model, processed_batches, _params.device, _params.dtype
|
|
50
54
|
)
|
|
51
|
-
predicted_batches = [self.model(batch) for batch in processed_batches]
|
|
55
|
+
predicted_batches = [self.model(batch) for batch in processed_batches] # type: ignore[misc]
|
|
52
56
|
# confidence
|
|
53
57
|
probs = [
|
|
54
58
|
torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
|
|
@@ -57,7 +61,7 @@ class OrientationPredictor(nn.Module):
|
|
|
57
61
|
predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]
|
|
58
62
|
|
|
59
63
|
class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
|
|
60
|
-
classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
|
|
64
|
+
classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs] # type: ignore[union-attr]
|
|
61
65
|
confs = [round(float(p), 2) for prob in probs for p in prob]
|
|
62
66
|
|
|
63
67
|
return [class_idxs, classes, confs]
|
|
@@ -3,11 +3,11 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List, Union
|
|
6
|
+
from typing import List, Optional, Union
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
10
|
-
from tensorflow import
|
|
10
|
+
from tensorflow.keras import Model
|
|
11
11
|
|
|
12
12
|
from doctr.models.preprocessor import PreProcessor
|
|
13
13
|
from doctr.utils.repr import NestedObject
|
|
@@ -29,11 +29,11 @@ class OrientationPredictor(NestedObject):
|
|
|
29
29
|
|
|
30
30
|
def __init__(
|
|
31
31
|
self,
|
|
32
|
-
pre_processor: PreProcessor,
|
|
33
|
-
model:
|
|
32
|
+
pre_processor: Optional[PreProcessor],
|
|
33
|
+
model: Optional[Model],
|
|
34
34
|
) -> None:
|
|
35
|
-
self.pre_processor = pre_processor
|
|
36
|
-
self.model = model
|
|
35
|
+
self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
|
|
36
|
+
self.model = model if isinstance(model, Model) else None
|
|
37
37
|
|
|
38
38
|
def __call__(
|
|
39
39
|
self,
|
|
@@ -43,6 +43,10 @@ class OrientationPredictor(NestedObject):
|
|
|
43
43
|
if any(input.ndim != 3 for input in inputs):
|
|
44
44
|
raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
|
|
45
45
|
|
|
46
|
+
if self.model is None or self.pre_processor is None:
|
|
47
|
+
# predictor is disabled
|
|
48
|
+
return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
|
|
49
|
+
|
|
46
50
|
processed_batches = self.pre_processor(inputs)
|
|
47
51
|
predicted_batches = [self.model(batch, training=False) for batch in processed_batches]
|
|
48
52
|
|
|
@@ -13,7 +13,7 @@ from tensorflow.keras.models import Sequential
|
|
|
13
13
|
|
|
14
14
|
from doctr.datasets import VOCABS
|
|
15
15
|
|
|
16
|
-
from ...utils import conv_sequence, load_pretrained_params
|
|
16
|
+
from ...utils import _build_model, conv_sequence, load_pretrained_params
|
|
17
17
|
|
|
18
18
|
__all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]
|
|
19
19
|
|
|
@@ -24,35 +24,35 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
24
24
|
"std": (0.299, 0.296, 0.301),
|
|
25
25
|
"input_shape": (32, 32, 3),
|
|
26
26
|
"classes": list(VOCABS["french"]),
|
|
27
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
27
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
|
|
28
28
|
},
|
|
29
29
|
"resnet31": {
|
|
30
30
|
"mean": (0.694, 0.695, 0.693),
|
|
31
31
|
"std": (0.299, 0.296, 0.301),
|
|
32
32
|
"input_shape": (32, 32, 3),
|
|
33
33
|
"classes": list(VOCABS["french"]),
|
|
34
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
34
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
|
|
35
35
|
},
|
|
36
36
|
"resnet34": {
|
|
37
37
|
"mean": (0.694, 0.695, 0.693),
|
|
38
38
|
"std": (0.299, 0.296, 0.301),
|
|
39
39
|
"input_shape": (32, 32, 3),
|
|
40
40
|
"classes": list(VOCABS["french"]),
|
|
41
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
41
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
|
|
42
42
|
},
|
|
43
43
|
"resnet50": {
|
|
44
44
|
"mean": (0.694, 0.695, 0.693),
|
|
45
45
|
"std": (0.299, 0.296, 0.301),
|
|
46
46
|
"input_shape": (32, 32, 3),
|
|
47
47
|
"classes": list(VOCABS["french"]),
|
|
48
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
48
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
|
|
49
49
|
},
|
|
50
50
|
"resnet34_wide": {
|
|
51
51
|
"mean": (0.694, 0.695, 0.693),
|
|
52
52
|
"std": (0.299, 0.296, 0.301),
|
|
53
53
|
"input_shape": (32, 32, 3),
|
|
54
54
|
"classes": list(VOCABS["french"]),
|
|
55
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
55
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
|
|
56
56
|
},
|
|
57
57
|
}
|
|
58
58
|
|
|
@@ -210,9 +210,15 @@ def _resnet(
|
|
|
210
210
|
model = ResNet(
|
|
211
211
|
num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
|
|
212
212
|
)
|
|
213
|
+
_build_model(model)
|
|
214
|
+
|
|
213
215
|
# Load pretrained parameters
|
|
214
216
|
if pretrained:
|
|
215
|
-
|
|
217
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
218
|
+
# skip the mismatching layers for fine tuning
|
|
219
|
+
load_pretrained_params(
|
|
220
|
+
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
221
|
+
)
|
|
216
222
|
|
|
217
223
|
return model
|
|
218
224
|
|
|
@@ -354,10 +360,17 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
354
360
|
)
|
|
355
361
|
|
|
356
362
|
model.cfg = _cfg
|
|
363
|
+
_build_model(model)
|
|
357
364
|
|
|
358
365
|
# Load pretrained parameters
|
|
359
366
|
if pretrained:
|
|
360
|
-
|
|
367
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
368
|
+
# skip the mismatching layers for fine tuning
|
|
369
|
+
load_pretrained_params(
|
|
370
|
+
model,
|
|
371
|
+
default_cfgs["resnet50"]["url"],
|
|
372
|
+
skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
|
|
373
|
+
)
|
|
361
374
|
|
|
362
375
|
return model
|
|
363
376
|
|
|
@@ -12,7 +12,7 @@ from tensorflow.keras import Sequential, layers
|
|
|
12
12
|
from doctr.datasets import VOCABS
|
|
13
13
|
|
|
14
14
|
from ...modules.layers.tensorflow import FASTConvLayer
|
|
15
|
-
from ...utils import conv_sequence, load_pretrained_params
|
|
15
|
+
from ...utils import _build_model, conv_sequence, load_pretrained_params
|
|
16
16
|
|
|
17
17
|
__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
|
|
18
18
|
|
|
@@ -22,21 +22,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
22
22
|
"std": (0.299, 0.296, 0.301),
|
|
23
23
|
"input_shape": (32, 32, 3),
|
|
24
24
|
"classes": list(VOCABS["french"]),
|
|
25
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
25
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
|
|
26
26
|
},
|
|
27
27
|
"textnet_small": {
|
|
28
28
|
"mean": (0.694, 0.695, 0.693),
|
|
29
29
|
"std": (0.299, 0.296, 0.301),
|
|
30
30
|
"input_shape": (32, 32, 3),
|
|
31
31
|
"classes": list(VOCABS["french"]),
|
|
32
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
32
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
|
|
33
33
|
},
|
|
34
34
|
"textnet_base": {
|
|
35
35
|
"mean": (0.694, 0.695, 0.693),
|
|
36
36
|
"std": (0.299, 0.296, 0.301),
|
|
37
37
|
"input_shape": (32, 32, 3),
|
|
38
38
|
"classes": list(VOCABS["french"]),
|
|
39
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
39
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
|
|
40
40
|
},
|
|
41
41
|
}
|
|
42
42
|
|
|
@@ -111,9 +111,15 @@ def _textnet(
|
|
|
111
111
|
|
|
112
112
|
# Build the model
|
|
113
113
|
model = TextNet(cfg=_cfg, **kwargs)
|
|
114
|
+
_build_model(model)
|
|
115
|
+
|
|
114
116
|
# Load pretrained parameters
|
|
115
117
|
if pretrained:
|
|
116
|
-
|
|
118
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
119
|
+
# skip the mismatching layers for fine tuning
|
|
120
|
+
load_pretrained_params(
|
|
121
|
+
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
122
|
+
)
|
|
117
123
|
|
|
118
124
|
return model
|
|
119
125
|
|
|
@@ -11,7 +11,7 @@ from tensorflow.keras.models import Sequential
|
|
|
11
11
|
|
|
12
12
|
from doctr.datasets import VOCABS
|
|
13
13
|
|
|
14
|
-
from ...utils import conv_sequence, load_pretrained_params
|
|
14
|
+
from ...utils import _build_model, conv_sequence, load_pretrained_params
|
|
15
15
|
|
|
16
16
|
__all__ = ["VGG", "vgg16_bn_r"]
|
|
17
17
|
|
|
@@ -22,7 +22,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
22
22
|
"std": (1.0, 1.0, 1.0),
|
|
23
23
|
"input_shape": (32, 32, 3),
|
|
24
24
|
"classes": list(VOCABS["french"]),
|
|
25
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
25
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
|
|
26
26
|
},
|
|
27
27
|
}
|
|
28
28
|
|
|
@@ -81,9 +81,15 @@ def _vgg(
|
|
|
81
81
|
|
|
82
82
|
# Build the model
|
|
83
83
|
model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
|
|
84
|
+
_build_model(model)
|
|
85
|
+
|
|
84
86
|
# Load pretrained parameters
|
|
85
87
|
if pretrained:
|
|
86
|
-
|
|
88
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
89
|
+
# skip the mismatching layers for fine tuning
|
|
90
|
+
load_pretrained_params(
|
|
91
|
+
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
92
|
+
)
|
|
87
93
|
|
|
88
94
|
return model
|
|
89
95
|
|
|
@@ -14,7 +14,7 @@ from doctr.models.modules.transformer import EncoderBlock
|
|
|
14
14
|
from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
|
|
15
15
|
from doctr.utils.repr import NestedObject
|
|
16
16
|
|
|
17
|
-
from ...utils import load_pretrained_params
|
|
17
|
+
from ...utils import _build_model, load_pretrained_params
|
|
18
18
|
|
|
19
19
|
__all__ = ["vit_s", "vit_b"]
|
|
20
20
|
|
|
@@ -25,14 +25,14 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
25
25
|
"std": (0.299, 0.296, 0.301),
|
|
26
26
|
"input_shape": (3, 32, 32),
|
|
27
27
|
"classes": list(VOCABS["french"]),
|
|
28
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
28
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
|
|
29
29
|
},
|
|
30
30
|
"vit_b": {
|
|
31
31
|
"mean": (0.694, 0.695, 0.693),
|
|
32
32
|
"std": (0.299, 0.296, 0.301),
|
|
33
33
|
"input_shape": (32, 32, 3),
|
|
34
34
|
"classes": list(VOCABS["french"]),
|
|
35
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
35
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
|
|
36
36
|
},
|
|
37
37
|
}
|
|
38
38
|
|
|
@@ -121,9 +121,15 @@ def _vit(
|
|
|
121
121
|
|
|
122
122
|
# Build the model
|
|
123
123
|
model = VisionTransformer(cfg=_cfg, **kwargs)
|
|
124
|
+
_build_model(model)
|
|
125
|
+
|
|
124
126
|
# Load pretrained parameters
|
|
125
127
|
if pretrained:
|
|
126
|
-
|
|
128
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
129
|
+
# skip the mismatching layers for fine tuning
|
|
130
|
+
load_pretrained_params(
|
|
131
|
+
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
132
|
+
)
|
|
127
133
|
|
|
128
134
|
return model
|
|
129
135
|
|
|
@@ -34,15 +34,27 @@ ARCHS: List[str] = [
|
|
|
34
34
|
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
def _orientation_predictor(
|
|
38
|
-
|
|
39
|
-
|
|
37
|
+
def _orientation_predictor(
|
|
38
|
+
arch: Any, pretrained: bool, model_type: str, disabled: bool = False, **kwargs: Any
|
|
39
|
+
) -> OrientationPredictor:
|
|
40
|
+
if disabled:
|
|
41
|
+
# Case where the orientation predictor is disabled
|
|
42
|
+
return OrientationPredictor(None, None)
|
|
43
|
+
|
|
44
|
+
if isinstance(arch, str):
|
|
45
|
+
if arch not in ORIENTATION_ARCHS:
|
|
46
|
+
raise ValueError(f"unknown architecture '{arch}'")
|
|
47
|
+
|
|
48
|
+
# Load directly classifier from backbone
|
|
49
|
+
_model = classification.__dict__[arch](pretrained=pretrained)
|
|
50
|
+
else:
|
|
51
|
+
if not isinstance(arch, classification.MobileNetV3):
|
|
52
|
+
raise ValueError(f"unknown architecture: {type(arch)}")
|
|
53
|
+
_model = arch
|
|
40
54
|
|
|
41
|
-
# Load directly classifier from backbone
|
|
42
|
-
_model = classification.__dict__[arch](pretrained=pretrained)
|
|
43
55
|
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
|
|
44
56
|
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
|
|
45
|
-
kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop"
|
|
57
|
+
kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
|
|
46
58
|
input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
|
|
47
59
|
predictor = OrientationPredictor(
|
|
48
60
|
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
|
|
@@ -51,7 +63,7 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient
|
|
|
51
63
|
|
|
52
64
|
|
|
53
65
|
def crop_orientation_predictor(
|
|
54
|
-
arch:
|
|
66
|
+
arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
|
|
55
67
|
) -> OrientationPredictor:
|
|
56
68
|
"""Crop orientation classification architecture.
|
|
57
69
|
|
|
@@ -71,11 +83,11 @@ def crop_orientation_predictor(
|
|
|
71
83
|
-------
|
|
72
84
|
OrientationPredictor
|
|
73
85
|
"""
|
|
74
|
-
return _orientation_predictor(arch, pretrained, **kwargs)
|
|
86
|
+
return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)
|
|
75
87
|
|
|
76
88
|
|
|
77
89
|
def page_orientation_predictor(
|
|
78
|
-
arch:
|
|
90
|
+
arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
|
|
79
91
|
) -> OrientationPredictor:
|
|
80
92
|
"""Page orientation classification architecture.
|
|
81
93
|
|
|
@@ -95,4 +107,4 @@ def page_orientation_predictor(
|
|
|
95
107
|
-------
|
|
96
108
|
OrientationPredictor
|
|
97
109
|
"""
|
|
98
|
-
return _orientation_predictor(arch, pretrained, **kwargs)
|
|
110
|
+
return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
|
|
@@ -10,12 +10,17 @@ from typing import Any, Dict, List, Optional, Tuple
|
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import tensorflow as tf
|
|
13
|
-
from tensorflow import
|
|
14
|
-
from tensorflow.keras import layers
|
|
13
|
+
from tensorflow.keras import Model, Sequential, layers, losses
|
|
15
14
|
from tensorflow.keras.applications import ResNet50
|
|
16
15
|
|
|
17
16
|
from doctr.file_utils import CLASS_NAME
|
|
18
|
-
from doctr.models.utils import
|
|
17
|
+
from doctr.models.utils import (
|
|
18
|
+
IntermediateLayerGetter,
|
|
19
|
+
_bf16_to_float32,
|
|
20
|
+
_build_model,
|
|
21
|
+
conv_sequence,
|
|
22
|
+
load_pretrained_params,
|
|
23
|
+
)
|
|
19
24
|
from doctr.utils.repr import NestedObject
|
|
20
25
|
|
|
21
26
|
from ...classification import mobilenet_v3_large
|
|
@@ -29,13 +34,13 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
29
34
|
"mean": (0.798, 0.785, 0.772),
|
|
30
35
|
"std": (0.264, 0.2749, 0.287),
|
|
31
36
|
"input_shape": (1024, 1024, 3),
|
|
32
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
37
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0",
|
|
33
38
|
},
|
|
34
39
|
"db_mobilenet_v3_large": {
|
|
35
40
|
"mean": (0.798, 0.785, 0.772),
|
|
36
41
|
"std": (0.264, 0.2749, 0.287),
|
|
37
42
|
"input_shape": (1024, 1024, 3),
|
|
38
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
43
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0",
|
|
39
44
|
},
|
|
40
45
|
}
|
|
41
46
|
|
|
@@ -81,7 +86,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
|
|
|
81
86
|
if dilation_factor > 1:
|
|
82
87
|
_layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest"))
|
|
83
88
|
|
|
84
|
-
module =
|
|
89
|
+
module = Sequential(_layers)
|
|
85
90
|
|
|
86
91
|
return module
|
|
87
92
|
|
|
@@ -104,7 +109,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
|
|
|
104
109
|
return layers.concatenate(results)
|
|
105
110
|
|
|
106
111
|
|
|
107
|
-
class DBNet(_DBNet,
|
|
112
|
+
class DBNet(_DBNet, Model, NestedObject):
|
|
108
113
|
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
|
|
109
114
|
<https://arxiv.org/pdf/1911.08947.pdf>`_.
|
|
110
115
|
|
|
@@ -147,14 +152,14 @@ class DBNet(_DBNet, keras.Model, NestedObject):
|
|
|
147
152
|
_inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
|
|
148
153
|
output_shape = tuple(self.fpn(_inputs).shape)
|
|
149
154
|
|
|
150
|
-
self.probability_head =
|
|
155
|
+
self.probability_head = Sequential([
|
|
151
156
|
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
|
|
152
157
|
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
|
|
153
158
|
layers.BatchNormalization(),
|
|
154
159
|
layers.Activation("relu"),
|
|
155
160
|
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
|
|
156
161
|
])
|
|
157
|
-
self.threshold_head =
|
|
162
|
+
self.threshold_head = Sequential([
|
|
158
163
|
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
|
|
159
164
|
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
|
|
160
165
|
layers.BatchNormalization(),
|
|
@@ -206,7 +211,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
|
|
|
206
211
|
|
|
207
212
|
# Focal loss
|
|
208
213
|
focal_scale = 10.0
|
|
209
|
-
bce_loss =
|
|
214
|
+
bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
|
|
210
215
|
|
|
211
216
|
# Convert logits to prob, compute gamma factor
|
|
212
217
|
p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
|
|
@@ -305,9 +310,16 @@ def _db_resnet(
|
|
|
305
310
|
|
|
306
311
|
# Build the model
|
|
307
312
|
model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
|
|
313
|
+
_build_model(model)
|
|
314
|
+
|
|
308
315
|
# Load pretrained parameters
|
|
309
316
|
if pretrained:
|
|
310
|
-
|
|
317
|
+
# The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
|
|
318
|
+
load_pretrained_params(
|
|
319
|
+
model,
|
|
320
|
+
_cfg["url"],
|
|
321
|
+
skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
|
|
322
|
+
)
|
|
311
323
|
|
|
312
324
|
return model
|
|
313
325
|
|
|
@@ -326,6 +338,10 @@ def _db_mobilenet(
|
|
|
326
338
|
# Patch the config
|
|
327
339
|
_cfg = deepcopy(default_cfgs[arch])
|
|
328
340
|
_cfg["input_shape"] = input_shape or _cfg["input_shape"]
|
|
341
|
+
if not kwargs.get("class_names", None):
|
|
342
|
+
kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
|
|
343
|
+
else:
|
|
344
|
+
kwargs["class_names"] = sorted(kwargs["class_names"])
|
|
329
345
|
|
|
330
346
|
# Feature extractor
|
|
331
347
|
feat_extractor = IntermediateLayerGetter(
|
|
@@ -339,9 +355,15 @@ def _db_mobilenet(
|
|
|
339
355
|
|
|
340
356
|
# Build the model
|
|
341
357
|
model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
|
|
358
|
+
_build_model(model)
|
|
342
359
|
# Load pretrained parameters
|
|
343
360
|
if pretrained:
|
|
344
|
-
|
|
361
|
+
# The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
|
|
362
|
+
load_pretrained_params(
|
|
363
|
+
model,
|
|
364
|
+
_cfg["url"],
|
|
365
|
+
skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
|
|
366
|
+
)
|
|
345
367
|
|
|
346
368
|
return model
|
|
347
369
|
|