python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries, and is provided for informational purposes only.
- doctr/__init__.py +1 -1
- doctr/contrib/__init__.py +0 -0
- doctr/contrib/artefacts.py +131 -0
- doctr/contrib/base.py +105 -0
- doctr/datasets/cord.py +10 -1
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +11 -1
- doctr/datasets/generator/base.py +6 -5
- doctr/datasets/ic03.py +11 -1
- doctr/datasets/ic13.py +10 -1
- doctr/datasets/iiit5k.py +26 -16
- doctr/datasets/imgur5k.py +11 -2
- doctr/datasets/loader.py +1 -6
- doctr/datasets/sroie.py +11 -1
- doctr/datasets/svhn.py +11 -1
- doctr/datasets/svt.py +11 -1
- doctr/datasets/synthtext.py +11 -1
- doctr/datasets/utils.py +9 -3
- doctr/datasets/vocabs.py +15 -4
- doctr/datasets/wildreceipt.py +12 -1
- doctr/file_utils.py +45 -12
- doctr/io/elements.py +52 -10
- doctr/io/html.py +2 -2
- doctr/io/image/pytorch.py +6 -8
- doctr/io/image/tensorflow.py +1 -1
- doctr/io/pdf.py +5 -2
- doctr/io/reader.py +6 -0
- doctr/models/__init__.py +0 -1
- doctr/models/_utils.py +57 -20
- doctr/models/builder.py +73 -15
- doctr/models/classification/magc_resnet/tensorflow.py +13 -6
- doctr/models/classification/mobilenet/pytorch.py +47 -9
- doctr/models/classification/mobilenet/tensorflow.py +51 -14
- doctr/models/classification/predictor/pytorch.py +28 -17
- doctr/models/classification/predictor/tensorflow.py +26 -16
- doctr/models/classification/resnet/tensorflow.py +21 -8
- doctr/models/classification/textnet/pytorch.py +3 -3
- doctr/models/classification/textnet/tensorflow.py +11 -5
- doctr/models/classification/vgg/tensorflow.py +9 -3
- doctr/models/classification/vit/tensorflow.py +10 -4
- doctr/models/classification/zoo.py +55 -19
- doctr/models/detection/_utils/__init__.py +1 -0
- doctr/models/detection/_utils/base.py +66 -0
- doctr/models/detection/differentiable_binarization/base.py +4 -3
- doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
- doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
- doctr/models/detection/fast/base.py +6 -5
- doctr/models/detection/fast/pytorch.py +4 -4
- doctr/models/detection/fast/tensorflow.py +15 -12
- doctr/models/detection/linknet/base.py +4 -3
- doctr/models/detection/linknet/tensorflow.py +23 -11
- doctr/models/detection/predictor/pytorch.py +15 -1
- doctr/models/detection/predictor/tensorflow.py +17 -3
- doctr/models/detection/zoo.py +7 -2
- doctr/models/factory/hub.py +8 -18
- doctr/models/kie_predictor/base.py +13 -3
- doctr/models/kie_predictor/pytorch.py +45 -20
- doctr/models/kie_predictor/tensorflow.py +44 -17
- doctr/models/modules/layers/pytorch.py +2 -3
- doctr/models/modules/layers/tensorflow.py +6 -8
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/transformer/tensorflow.py +0 -2
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/modules/vision_transformer/tensorflow.py +1 -1
- doctr/models/predictor/base.py +97 -58
- doctr/models/predictor/pytorch.py +35 -20
- doctr/models/predictor/tensorflow.py +35 -18
- doctr/models/preprocessor/pytorch.py +4 -4
- doctr/models/preprocessor/tensorflow.py +3 -2
- doctr/models/recognition/crnn/tensorflow.py +8 -6
- doctr/models/recognition/master/pytorch.py +2 -2
- doctr/models/recognition/master/tensorflow.py +9 -4
- doctr/models/recognition/parseq/pytorch.py +4 -3
- doctr/models/recognition/parseq/tensorflow.py +14 -11
- doctr/models/recognition/sar/pytorch.py +7 -6
- doctr/models/recognition/sar/tensorflow.py +10 -12
- doctr/models/recognition/vitstr/pytorch.py +1 -1
- doctr/models/recognition/vitstr/tensorflow.py +9 -4
- doctr/models/recognition/zoo.py +1 -1
- doctr/models/utils/pytorch.py +1 -1
- doctr/models/utils/tensorflow.py +15 -15
- doctr/models/zoo.py +2 -2
- doctr/py.typed +0 -0
- doctr/transforms/functional/base.py +1 -1
- doctr/transforms/functional/pytorch.py +5 -5
- doctr/transforms/modules/base.py +37 -15
- doctr/transforms/modules/pytorch.py +73 -14
- doctr/transforms/modules/tensorflow.py +78 -19
- doctr/utils/fonts.py +7 -5
- doctr/utils/geometry.py +141 -31
- doctr/utils/metrics.py +34 -175
- doctr/utils/reconstitution.py +212 -0
- doctr/utils/visualization.py +5 -118
- doctr/version.py +1 -1
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
- python_doctr-0.10.0.dist-info/RECORD +173 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
- doctr/models/artefacts/__init__.py +0 -2
- doctr/models/artefacts/barcode.py +0 -74
- doctr/models/artefacts/face.py +0 -63
- doctr/models/obj_detection/__init__.py +0 -1
- doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
- python_doctr-0.8.1.dist-info/RECORD +0 -173
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/classification/predictor/pytorch.py:

```diff
@@ -3,7 +3,7 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -12,12 +12,12 @@ from torch import nn
 from doctr.models.preprocessor import PreProcessor
 from doctr.models.utils import set_device_and_dtype
 
-__all__ = ["CropOrientationPredictor"]
+__all__ = ["OrientationPredictor"]
 
 
-class CropOrientationPredictor(nn.Module):
-    """Implements an object able to detect the reading direction of a text box.
-    4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
+class OrientationPredictor(nn.Module):
+    """Implements an object able to detect the reading direction of a text box or a page.
+    4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
 
     Args:
     ----
@@ -27,30 +27,41 @@ class CropOrientationPredictor(nn.Module):
 
     def __init__(
         self,
-        pre_processor: PreProcessor,
-        model: nn.Module,
+        pre_processor: Optional[PreProcessor],
+        model: Optional[nn.Module],
     ) -> None:
         super().__init__()
-        self.pre_processor = pre_processor
-        self.model = model.eval()
+        self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
+        self.model = model.eval() if isinstance(model, nn.Module) else None
 
     @torch.inference_mode()
     def forward(
         self,
-        crops: List[Union[np.ndarray, torch.Tensor]],
-    ) -> List[int]:
+        inputs: List[Union[np.ndarray, torch.Tensor]],
+    ) -> List[Union[List[int], List[float]]]:
         # Dimension check
-        if any(crop.ndim != 3 for crop in crops):
-            raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
+        if any(input.ndim != 3 for input in inputs):
+            raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
 
-        processed_batches = self.pre_processor(crops)
+        if self.model is None or self.pre_processor is None:
+            # predictor is disabled
+            return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
+
+        processed_batches = self.pre_processor(inputs)
         _params = next(self.model.parameters())
         self.model, processed_batches = set_device_and_dtype(
             self.model, processed_batches, _params.device, _params.dtype
         )
-        predicted_batches = [self.model(batch) for batch in processed_batches]
-
+        predicted_batches = [self.model(batch) for batch in processed_batches]  # type: ignore[misc]
+        # confidence
+        probs = [
+            torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
+        ]
        # Postprocess predictions
         predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]
 
-        return [int(pred) for batch in predicted_batches for pred in batch]
+        class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
+        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]  # type: ignore[union-attr]
+        confs = [round(float(p), 2) for prob in probs for p in prob]
+
+        return [class_idxs, classes, confs]
```
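The API change above is the breaking one in this file: the predictor, renamed from `CropOrientationPredictor` to `OrientationPredictor`, no longer returns a flat `List[int]` of class indices but three parallel lists: class indices, class labels mapped through `cfg["classes"]` (angles in degrees), and rounded softmax confidences. A minimal consumption sketch, assuming a doctr 0.10.x install with pretrained weights available:

```python
import numpy as np

from doctr.models import crop_orientation_predictor

# Factory from the classification zoo (see the zoo.py hunks below)
predictor = crop_orientation_predictor(pretrained=True)

crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8)

# 0.10.0: [class_idxs, classes, confs] instead of a flat List[int]
class_idxs, angles, confs = predictor([crop])
print(angles[0], confs[0])  # e.g. 0 and 1.0 for an upright crop
```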
doctr/models/classification/predictor/tensorflow.py:

```diff
@@ -3,21 +3,21 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
+from tensorflow.keras import Model
 
 from doctr.models.preprocessor import PreProcessor
 from doctr.utils.repr import NestedObject
 
-__all__ = ["CropOrientationPredictor"]
+__all__ = ["OrientationPredictor"]
 
 
-class CropOrientationPredictor(NestedObject):
-    """Implements an object able to detect the reading direction of a text box.
-    4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
+class OrientationPredictor(NestedObject):
+    """Implements an object able to detect the reading direction of a text box or a page.
+    4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
 
     Args:
     ----
@@ -29,24 +29,34 @@ class CropOrientationPredictor(NestedObject):
 
     def __init__(
         self,
-        pre_processor: PreProcessor,
-        model: keras.Model,
+        pre_processor: Optional[PreProcessor],
+        model: Optional[Model],
     ) -> None:
-        self.pre_processor = pre_processor
-        self.model = model
+        self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
+        self.model = model if isinstance(model, Model) else None
 
     def __call__(
         self,
-        crops: List[Union[np.ndarray, tf.Tensor]],
-    ) -> List[int]:
+        inputs: List[Union[np.ndarray, tf.Tensor]],
+    ) -> List[Union[List[int], List[float]]]:
         # Dimension check
-        if any(crop.ndim != 3 for crop in crops):
-            raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
+        if any(input.ndim != 3 for input in inputs):
+            raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
 
-        processed_batches = self.pre_processor(crops)
+        if self.model is None or self.pre_processor is None:
+            # predictor is disabled
+            return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
+
+        processed_batches = self.pre_processor(inputs)
         predicted_batches = [self.model(batch, training=False) for batch in processed_batches]
 
+        # confidence
+        probs = [tf.math.reduce_max(tf.nn.softmax(batch, axis=1), axis=1).numpy() for batch in predicted_batches]
         # Postprocess predictions
         predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches]
 
-        return [int(pred) for batch in predicted_batches for pred in batch]
+        class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
+        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
+        confs = [round(float(p), 2) for prob in probs for p in prob]
+
+        return [class_idxs, classes, confs]
```
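Both backends now tolerate `None` for the model and pre-processor, which is how the zoo builds a disabled predictor: the call short-circuits and reports angle 0 with confidence 1.0 for every input. A sketch of that degenerate path (backend-agnostic, since both classes share the same contract):

```python
import numpy as np

from doctr.models.classification.predictor import OrientationPredictor

# Disabled predictor: no model, no pre-processor
predictor = OrientationPredictor(None, None)

pages = [np.zeros((512, 512, 3), dtype=np.uint8) for _ in range(2)]
class_idxs, angles, confs = predictor(pages)
print(class_idxs, angles, confs)  # [0, 0] [0, 0] [1.0, 1.0]
```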
doctr/models/classification/resnet/tensorflow.py:

```diff
@@ -13,7 +13,7 @@ from tensorflow.keras.models import Sequential
 
 from doctr.datasets import VOCABS
 
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]
 
@@ -24,35 +24,35 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
     },
     "resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
     },
     "resnet34": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
     },
     "resnet50": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
     },
     "resnet34_wide": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
     },
 }
 
@@ -210,9 +210,15 @@ def _resnet(
     model = ResNet(
         num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
@@ -354,10 +360,17 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
     )
 
     model.cfg = _cfg
+    _build_model(model)
 
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs["resnet50"]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            default_cfgs["resnet50"]["url"],
+            skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
+        )
 
     return model
 
```
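Two changes repeat across every classification backbone in this release (the same pattern appears in the TextNet, VGG and ViT hunks below): `_build_model(model)` materializes the Keras weights before any checkpoint is touched, and `load_pretrained_params` gains a `skip_mismatch` flag so a custom head no longer blocks loading the backbone. A hedged usage sketch of what this enables; the `num_classes` plumbing follows the `kwargs["num_classes"]` access visible in the diff:

```python
from doctr.models import resnet18

# Stock head: matches the checkpoint, every layer loads
model = resnet18(pretrained=True)

# Custom 10-class head: num_classes != len(default classes), so the factory
# passes skip_mismatch=True and only the matching (backbone) layers load
model_ft = resnet18(pretrained=True, num_classes=10)
```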
doctr/models/classification/textnet/pytorch.py:

```diff
@@ -22,21 +22,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-27288d12.pt&src=0",
     },
     "textnet_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-43166ee6.pt&src=0",
     },
     "textnet_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-7f68d7e0.pt&src=0",
     },
 }
 
```
doctr/models/classification/textnet/tensorflow.py:

```diff
@@ -12,7 +12,7 @@ from tensorflow.keras import Sequential, layers
 from doctr.datasets import VOCABS
 
 from ...modules.layers.tensorflow import FASTConvLayer
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
 
@@ -22,21 +22,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
     },
     "textnet_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
     },
     "textnet_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
     },
 }
 
@@ -111,9 +111,15 @@ def _textnet(
 
     # Build the model
     model = TextNet(cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
```
doctr/models/classification/vgg/tensorflow.py:

```diff
@@ -11,7 +11,7 @@ from tensorflow.keras.models import Sequential
 
 from doctr.datasets import VOCABS
 
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = ["VGG", "vgg16_bn_r"]
 
@@ -22,7 +22,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (1.0, 1.0, 1.0),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
     },
 }
 
@@ -81,9 +81,15 @@ def _vgg(
 
     # Build the model
     model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
```
doctr/models/classification/vit/tensorflow.py:

```diff
@@ -14,7 +14,7 @@ from doctr.models.modules.transformer import EncoderBlock
 from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
 from doctr.utils.repr import NestedObject
 
-from ...utils import load_pretrained_params
+from ...utils import _build_model, load_pretrained_params
 
 __all__ = ["vit_s", "vit_b"]
 
@@ -25,14 +25,14 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
     },
     "vit_b": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
     },
 }
 
@@ -121,9 +121,15 @@ def _vit(
 
     # Build the model
     model = VisionTransformer(cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
```
doctr/models/classification/zoo.py:

```diff
@@ -9,9 +9,9 @@ from doctr.file_utils import is_tf_available
 
 from .. import classification
 from ..preprocessor import PreProcessor
-from .predictor import CropOrientationPredictor
+from .predictor import OrientationPredictor
 
-__all__ = ["crop_orientation_predictor"]
+__all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
 
 ARCHS: List[str] = [
     "magc_resnet31",
@@ -31,44 +31,80 @@ ARCHS: List[str] = [
     "vit_s",
     "vit_b",
 ]
-ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
+ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
-def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> CropOrientationPredictor:
-    if arch not in ORIENTATION_ARCHS:
-        raise ValueError(f"unknown architecture '{arch}'")
+def _orientation_predictor(
+    arch: Any, pretrained: bool, model_type: str, disabled: bool = False, **kwargs: Any
+) -> OrientationPredictor:
+    if disabled:
+        # Case where the orientation predictor is disabled
+        return OrientationPredictor(None, None)
+
+    if isinstance(arch, str):
+        if arch not in ORIENTATION_ARCHS:
+            raise ValueError(f"unknown architecture '{arch}'")
+
+        # Load directly classifier from backbone
+        _model = classification.__dict__[arch](pretrained=pretrained)
+    else:
+        if not isinstance(arch, classification.MobileNetV3):
+            raise ValueError(f"unknown architecture: {type(arch)}")
+        _model = arch
 
-    # Load directly classifier from backbone
-    _model = classification.__dict__[arch](pretrained=pretrained)
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size", 64)
+    kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
     input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
-    predictor = CropOrientationPredictor(
+    predictor = OrientationPredictor(
         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
     )
     return predictor
 
 
 def crop_orientation_predictor(
-    arch: str = "mobilenet_v3_small_orientation", pretrained: bool = False, **kwargs: Any
-) -> CropOrientationPredictor:
-    """Orientation classification architecture.
+    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
+) -> OrientationPredictor:
+    """Crop orientation classification architecture.
 
     >>> import numpy as np
     >>> from doctr.models import crop_orientation_predictor
-    >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_orientation', pretrained=True)
-    >>> input_crop = (255 * np.random.rand(
+    >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation', pretrained=True)
+    >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8)
     >>> out = model([input_crop])
 
     Args:
     ----
-        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_orientation')
+        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
+        pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        **kwargs: keyword arguments to be passed to the OrientationPredictor
+
+    Returns:
+    -------
+        OrientationPredictor
+    """
+    return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)
+
+
+def page_orientation_predictor(
+    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
+) -> OrientationPredictor:
+    """Page orientation classification architecture.
+
+    >>> import numpy as np
+    >>> from doctr.models import page_orientation_predictor
+    >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation', pretrained=True)
+    >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
+    >>> out = model([input_page])
+
+    Args:
+    ----
+        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
-        **kwargs: keyword arguments to be passed to the CropOrientationPredictor
+        **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
     -------
-        CropOrientationPredictor
+        OrientationPredictor
     """
-    return _crop_orientation_predictor(arch, pretrained, **kwargs)
+    return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
```
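Putting the zoo changes together: `crop_orientation_predictor` keeps its name but now returns the richer `OrientationPredictor`, and `page_orientation_predictor` is new. A usage sketch following the docstrings above; the `disabled` kwarg is forwarded through `**kwargs` to `_orientation_predictor`:

```python
import numpy as np

from doctr.models import crop_orientation_predictor, page_orientation_predictor

crop_clf = crop_orientation_predictor(pretrained=True)
page_clf = page_orientation_predictor(pretrained=True)

page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
idxs, angles, confs = page_clf([page])  # [class_idxs, classes, confs]

# Forwarding disabled=True yields the no-op OrientationPredictor(None, None)
noop_clf = crop_orientation_predictor(disabled=True)
```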
doctr/models/detection/_utils/base.py (new file):

```diff
@@ -0,0 +1,66 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Dict, List
+
+import numpy as np
+
+__all__ = ["_remove_padding"]
+
+
+def _remove_padding(
+    pages: List[np.ndarray],
+    loc_preds: List[Dict[str, np.ndarray]],
+    preserve_aspect_ratio: bool,
+    symmetric_pad: bool,
+    assume_straight_pages: bool,
+) -> List[Dict[str, np.ndarray]]:
+    """Remove padding from the localization predictions
+
+    Args:
+    ----
+        pages: list of pages
+        loc_preds: list of localization predictions
+        preserve_aspect_ratio: whether the aspect ratio was preserved during padding
+        symmetric_pad: whether the padding was symmetric
+        assume_straight_pages: whether the pages are assumed to be straight
+
+    Returns:
+    -------
+        list of unpaded localization predictions
+    """
+    if preserve_aspect_ratio:
+        # Rectify loc_preds to remove padding
+        rectified_preds = []
+        for page, dict_loc_preds in zip(pages, loc_preds):
+            for k, loc_pred in dict_loc_preds.items():
+                h, w = page.shape[0], page.shape[1]
+                if h > w:
+                    # y unchanged, dilate x coord
+                    if symmetric_pad:
+                        if assume_straight_pages:
+                            loc_pred[:, [0, 2]] = (loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5
+                        else:
+                            loc_pred[:, :, 0] = (loc_pred[:, :, 0] - 0.5) * h / w + 0.5
+                    else:
+                        if assume_straight_pages:
+                            loc_pred[:, [0, 2]] *= h / w
+                        else:
+                            loc_pred[:, :, 0] *= h / w
+                elif w > h:
+                    # x unchanged, dilate y coord
+                    if symmetric_pad:
+                        if assume_straight_pages:
+                            loc_pred[:, [1, 3]] = (loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5
+                        else:
+                            loc_pred[:, :, 1] = (loc_pred[:, :, 1] - 0.5) * w / h + 0.5
+                    else:
+                        if assume_straight_pages:
+                            loc_pred[:, [1, 3]] *= w / h
+                        else:
+                            loc_pred[:, :, 1] *= w / h
+                rectified_preds.append({k: np.clip(loc_pred, 0, 1)})
+        return rectified_preds
+    return loc_preds
```
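The rectification math in `_remove_padding` is easiest to check on a concrete case. A portrait page (h > w) padded symmetrically to a square has its x-coordinates squeezed into a central band of relative width w/h; dilating around the 0.5 midline by h/w undoes it. A toy check of the straight-box, symmetric-pad branch:

```python
import numpy as np

h, w = 1000, 500  # portrait page, padded left/right into a 1000x1000 square

# A box spanning the full page width sits on [0.25, 0.75] in padded coords
box = np.array([[0.25, 0.1, 0.75, 0.3]])  # relative (xmin, ymin, xmax, ymax)

# h > w, symmetric pad: dilate x around the midline, as in the new helper
box[:, [0, 2]] = (box[:, [0, 2]] - 0.5) * h / w + 0.5
print(np.clip(box, 0, 1))  # [[0.  0.1 1.  0.3]] -> full width recovered
```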
doctr/models/detection/differentiable_binarization/base.py:

```diff
@@ -114,7 +114,7 @@ class DBPostProcessor(DetectionPostProcessor):
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):  # type: ignore[index]
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -150,10 +150,11 @@ class DBPostProcessor(DetectionPostProcessor):
                 raise AssertionError("When assume straight pages is false a box is a (4, 2) array (polygon)")
             _box[:, 0] /= width
             _box[:, 1] /= height
-            boxes.append(_box)
+            # Add score to box as (0, score)
+            boxes.append(np.vstack([_box, np.array([0.0, score])]))
 
         if not self.assume_straight_pages:
-            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
         else:
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
 
```
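The second hunk explains the first: in the rotated-page branch every polygon now carries its score as a fifth `(0, score)` row, hence the empty-result shape moving from `(0, 4, 2)` to `(0, 5, 2)`. A sketch of packing and unpacking that layout:

```python
import numpy as np

polygon = np.array([[0.1, 0.2], [0.4, 0.2], [0.4, 0.5], [0.1, 0.5]])
score = 0.87

# Pack: stack the score row under the 4 corners -> shape (5, 2)
packed = np.vstack([polygon, np.array([0.0, score])])

# Unpack downstream: rows 0-3 are the corners, element [4, 1] is the score
corners, conf = packed[:4], float(packed[4, 1])
print(packed.shape, conf)  # (5, 2) 0.87
```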
doctr/models/detection/differentiable_binarization/pytorch.py:

```diff
@@ -39,7 +39,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/db_mobilenet_v3_large-21748dd0.pt&src=0",
     },
 }
 
@@ -273,7 +273,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
```
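The identical-looking `-`/`+` pair in the last hunk is apparently a whitespace-only edit (the extracted text matches on both sides). The line itself is the differentiable binarization at the core of DBNet: a steep sigmoid `1 / (1 + exp(-50 * (P - T)))` that approximates thresholding the probability map P against the learned threshold map T while remaining differentiable for the dice loss. A toy illustration:

```python
import torch

prob_map = torch.tensor([0.30, 0.49, 0.51, 0.90])
thresh_map = torch.full((4,), 0.5)

# Steepness 50 makes the soft threshold behave almost like a step function
dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
print(torch.round(dice_map, decimals=2))  # ~[0.00, 0.38, 0.62, 1.00]
```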