python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/classification/vit/pytorch.py CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

 from copy import deepcopy
-from typing import Any
+from typing import Any

 import torch
 from torch import nn
@@ -18,7 +18,7 @@ from ...utils.pytorch import load_pretrained_params
 __all__ = ["vit_s", "vit_b"]


-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "vit_s": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -40,7 +40,6 @@ class ClassifierHead(nn.Module):
     """Classifier head for Vision Transformer

     Args:
-    ----
         in_channels: number of input channels
         num_classes: number of output classes
     """
@@ -65,7 +64,6 @@ class VisionTransformer(nn.Sequential):
     <https://arxiv.org/pdf/2010.11929.pdf>`_.

     Args:
-    ----
         d_model: dimension of the transformer layers
         num_layers: number of transformer layers
         num_heads: number of attention heads
@@ -83,14 +81,14 @@ class VisionTransformer(nn.Sequential):
         num_layers: int,
         num_heads: int,
         ffd_ratio: int,
-        patch_size:
-        input_shape:
+        patch_size: tuple[int, int] = (4, 4),
+        input_shape: tuple[int, int, int] = (3, 32, 32),
         dropout: float = 0.0,
         num_classes: int = 1000,
         include_top: bool = True,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
-        _layers:
+        _layers: list[nn.Module] = [
             PatchEmbedding(input_shape, d_model, patch_size),
             EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU()),
         ]
@@ -104,7 +102,7 @@ class VisionTransformer(nn.Sequential):
 def _vit(
     arch: str,
     pretrained: bool,
-    ignore_keys:
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> VisionTransformer:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -143,12 +141,10 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)

     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture

     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
@@ -175,12 +171,10 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)

     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture

     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
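For reference, the public factory API of these classifiers is unchanged by the typing cleanup. A minimal usage sketch with the PyTorch backend, mirroring the docstring example quoted above (assumes doctr is installed with torch):

>>> import torch
>>> from doctr.models import vit_s
>>> model = vit_s(pretrained=False)
>>> out = model(torch.rand((1, 3, 32, 32), dtype=torch.float32))  # class logits for the single input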
doctr/models/classification/vit/tensorflow.py CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

 from copy import deepcopy
-from typing import Any
+from typing import Any

 import tensorflow as tf
 from tensorflow.keras import Sequential, layers
@@ -14,25 +14,25 @@ from doctr.models.modules.transformer import EncoderBlock
 from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
 from doctr.utils.repr import NestedObject

-from ...utils import load_pretrained_params
+from ...utils import _build_model, load_pretrained_params

 __all__ = ["vit_s", "vit_b"]


-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "vit_s": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
     },
     "vit_b": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
     },
 }

@@ -41,7 +41,6 @@ class ClassifierHead(layers.Layer, NestedObject):
     """Classifier head for Vision Transformer

     Args:
-    ----
         num_classes: number of output classes
     """

@@ -61,7 +60,6 @@ class VisionTransformer(Sequential):
     <https://arxiv.org/pdf/2010.11929.pdf>`_.

     Args:
-    ----
         d_model: dimension of the transformer layers
         num_layers: number of transformer layers
         num_heads: number of attention heads
@@ -79,12 +77,12 @@ class VisionTransformer(Sequential):
         num_layers: int,
         num_heads: int,
         ffd_ratio: int,
-        patch_size:
-        input_shape:
+        patch_size: tuple[int, int] = (4, 4),
+        input_shape: tuple[int, int, int] = (32, 32, 3),
         dropout: float = 0.0,
         num_classes: int = 1000,
         include_top: bool = True,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         _layers = [
             PatchEmbedding(input_shape, d_model, patch_size),
@@ -121,9 +119,15 @@ def _vit(

     # Build the model
     model = VisionTransformer(cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )

     return model

@@ -142,12 +146,10 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)

     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture

     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
@@ -173,12 +175,10 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)

     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture

     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
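The added `_build_model` call plus the `skip_mismatch` flag mean a pretrained checkpoint can now be loaded even when the classifier head has a different number of classes. A hedged sketch of what that enables with the TensorFlow backend (the exact fine-tuning behaviour is assumed from the diff above, not verified here):

>>> from doctr.models import vit_s
>>> # backbone weights come from the published checkpoint; the mismatching head is skipped
>>> model = vit_s(pretrained=True, num_classes=10)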
doctr/models/classification/zoo.py CHANGED

@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-from typing import Any
+from typing import Any

-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available

 from .. import classification
 from ..preprocessor import PreProcessor
@@ -13,7 +13,7 @@ from .predictor import OrientationPredictor

 __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]

-ARCHS:
+ARCHS: list[str] = [
     "magc_resnet31",
     "mobilenet_v3_small",
     "mobilenet_v3_small_r",
@@ -31,18 +31,37 @@ ARCHS: List[str] = [
     "vit_s",
     "vit_b",
 ]
-ORIENTATION_ARCHS:
+ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


-def _orientation_predictor(
-
-
+def _orientation_predictor(
+    arch: Any, pretrained: bool, model_type: str, disabled: bool = False, **kwargs: Any
+) -> OrientationPredictor:
+    if disabled:
+        # Case where the orientation predictor is disabled
+        return OrientationPredictor(None, None)
+
+    if isinstance(arch, str):
+        if arch not in ORIENTATION_ARCHS:
+            raise ValueError(f"unknown architecture '{arch}'")
+
+        # Load directly classifier from backbone
+        _model = classification.__dict__[arch](pretrained=pretrained)
+    else:
+        allowed_archs = [classification.MobileNetV3]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
+            raise ValueError(f"unknown architecture: {type(arch)}")
+        _model = arch

-    # Load directly classifier from backbone
-    _model = classification.__dict__[arch](pretrained=pretrained)
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop"
+    kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
     input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
     predictor = OrientationPredictor(
         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
@@ -51,7 +70,7 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient


 def crop_orientation_predictor(
-    arch:
+    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, batch_size: int = 128, **kwargs: Any
 ) -> OrientationPredictor:
     """Crop orientation classification architecture.

@@ -62,20 +81,19 @@ def crop_orientation_predictor(
     >>> out = model([input_crop])

     Args:
-    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        batch_size: number of samples the model processes in parallel
         **kwargs: keyword arguments to be passed to the OrientationPredictor

     Returns:
-    -------
         OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, **kwargs)
+    return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="crop", **kwargs)


 def page_orientation_predictor(
-    arch:
+    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, batch_size: int = 4, **kwargs: Any
 ) -> OrientationPredictor:
     """Page orientation classification architecture.

@@ -86,13 +104,12 @@ def page_orientation_predictor(
     >>> out = model([input_page])

     Args:
-    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        batch_size: number of samples the model processes in parallel
         **kwargs: keyword arguments to be passed to the OrientationPredictor

     Returns:
-    -------
         OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, **kwargs)
+    return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="page", **kwargs)
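In practice the signature change surfaces as an explicit batch_size argument and the option to hand the factory an already-instantiated classification model instead of an architecture name. A sketch based on the docstring example above (assuming the predictors are still re-exported under doctr.models as in 0.9.0):

>>> import numpy as np
>>> from doctr.models import crop_orientation_predictor
>>> predictor = crop_orientation_predictor("mobilenet_v3_small_crop_orientation", pretrained=True, batch_size=64)
>>> crops = [(255 * np.random.rand(256, 256, 3)).astype(np.uint8)]
>>> orientations = predictor(crops)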
doctr/models/core.py CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


-from typing import Any
+from typing import Any

 from doctr.utils.repr import NestedObject

@@ -14,6 +14,6 @@ __all__ = ["BaseModel"]
 class BaseModel(NestedObject):
     """Implements abstract DetectionModel class"""

-    def __init__(self, cfg:
+    def __init__(self, cfg: dict[str, Any] | None = None) -> None:
         super().__init__()
         self.cfg = cfg
doctr/models/detection/_utils/__init__.py CHANGED

@@ -1,7 +1,7 @@
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from .base import *

-if
-    from .tensorflow import *
-else:
+if is_torch_available():
     from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *
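The backend-selection __init__ files (this one and the differentiable_binarization one further down) now check for PyTorch first, so when both frameworks are installed the torch implementations win by default. As before, the choice can be forced through doctr's environment switches (USE_TF / USE_TORCH, documented in the project README); a minimal sketch:

>>> import os
>>> os.environ["USE_TF"] = "1"  # force the TensorFlow implementations before importing doctr
>>> import doctr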
doctr/models/detection/_utils/base.py CHANGED

@@ -1,9 +1,8 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-from typing import Dict, List

 import numpy as np

@@ -11,16 +10,15 @@ __all__ = ["_remove_padding"]


 def _remove_padding(
-    pages:
-    loc_preds:
+    pages: list[np.ndarray],
+    loc_preds: list[dict[str, np.ndarray]],
     preserve_aspect_ratio: bool,
     symmetric_pad: bool,
     assume_straight_pages: bool,
-) ->
+) -> list[dict[str, np.ndarray]]:
     """Remove padding from the localization predictions

     Args:
-    ----
         pages: list of pages
         loc_preds: list of localization predictions
         preserve_aspect_ratio: whether the aspect ratio was preserved during padding
@@ -28,7 +26,6 @@ def _remove_padding(
         assume_straight_pages: whether the pages are assumed to be straight

     Returns:
-    -------
         list of unpaded localization predictions
     """
     if preserve_aspect_ratio:
doctr/models/detection/_utils/pytorch.py CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,12 +13,10 @@ def erode(x: Tensor, kernel_size: int) -> Tensor:
     """Performs erosion on a given tensor

     Args:
-    ----
         x: boolean tensor of shape (N, C, H, W)
         kernel_size: the size of the kernel to use for erosion

     Returns:
-    -------
         the eroded tensor
     """
     _pad = (kernel_size - 1) // 2
@@ -30,12 +28,10 @@ def dilate(x: Tensor, kernel_size: int) -> Tensor:
     """Performs dilation on a given tensor

     Args:
-    ----
         x: boolean tensor of shape (N, C, H, W)
         kernel_size: the size of the kernel to use for dilation

     Returns:
-    -------
         the dilated tensor
     """
     _pad = (kernel_size - 1) // 2
doctr/models/detection/_utils/tensorflow.py CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,12 +12,10 @@ def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
     """Performs erosion on a given tensor

     Args:
-    ----
         x: boolean tensor of shape (N, H, W, C)
         kernel_size: the size of the kernel to use for erosion

     Returns:
-    -------
         the eroded tensor
     """
     return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
@@ -27,12 +25,10 @@ def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
     """Performs dilation on a given tensor

     Args:
-    ----
         x: boolean tensor of shape (N, H, W, C)
         kernel_size: the size of the kernel to use for dilation

     Returns:
-    -------
         the dilated tensor
     """
     return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
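The TensorFlow implementations quoted above express morphology through max pooling: dilation is a plain max-pool of the binary map, and erosion is the complement of a dilation of the complement. A standalone sketch of that identity, independent of doctr:

>>> import tensorflow as tf
>>> x = tf.cast(tf.random.uniform((1, 64, 64, 1)) > 0.5, tf.float32)       # binary map, NHWC layout
>>> dilated = tf.nn.max_pool2d(x, 3, strides=1, padding="SAME")            # dilation with a 3x3 kernel
>>> eroded = 1 - tf.nn.max_pool2d(1 - x, 3, strides=1, padding="SAME")     # erosion with a 3x3 kernel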
doctr/models/detection/core.py CHANGED

@@ -1,9 +1,8 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-from typing import List

 import cv2
 import numpy as np
@@ -17,7 +16,6 @@ class DetectionPostProcessor(NestedObject):
     """Abstract class to postprocess the raw output of the model

     Args:
-    ----
         box_thresh (float): minimal objectness score to consider a box
         bin_thresh (float): threshold to apply to segmentation raw heatmap
         assume straight_pages (bool): if True, fit straight boxes only
@@ -37,13 +35,11 @@ class DetectionPostProcessor(NestedObject):
         """Compute the confidence score for a polygon : mean of the p values on the polygon

         Args:
-        ----
             pred (np.ndarray): p map returned by the model
             points: coordinates of the polygon
             assume_straight_pages: if True, fit straight boxes only

         Returns:
-        -------
             polygon objectness
         """
         h, w = pred.shape[:2]
@@ -71,15 +67,13 @@ class DetectionPostProcessor(NestedObject):
     def __call__(
         self,
         proba_map,
-    ) ->
+    ) -> list[list[np.ndarray]]:
         """Performs postprocessing for a list of model outputs

         Args:
-        ----
             proba_map: probability map of shape (N, H, W, C)

         Returns:
-        -------
             list of N class predictions (for each input sample), where each class predictions is a list of C tensors
             of shape (*, 5) or (*, 6)
         """
doctr/models/detection/differentiable_binarization/__init__.py CHANGED

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available

-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/detection/differentiable_binarization/base.py CHANGED

@@ -1,11 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization

-from typing import Dict, List, Tuple, Union

 import cv2
 import numpy as np
@@ -22,7 +21,6 @@ class DBPostProcessor(DetectionPostProcessor):
     <https://github.com/xuannianz/DifferentiableBinarization>`_.

     Args:
-    ----
         unclip ratio: ratio used to unshrink polygons
         min_size_box: minimal length (pix) to keep a box
         max_candidates: maximum boxes to consider in a single page
@@ -47,11 +45,9 @@ class DBPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon

         Args:
-        ----
             points: The first parameter.

         Returns:
-        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -96,25 +92,23 @@ class DBPostProcessor(DetectionPostProcessor):
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes

         Args:
-        ----
             pred: Pred map from differentiable binarization output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop

         Returns:
-        -------
             np tensor boxes for the bitmap, each box is a 5-element list
                 containing x, y, w, h, score for the box
         """
         height, width = bitmap.shape[:2]
         min_size_box = 2
-        boxes:
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -164,7 +158,6 @@ class _DBNet:
     <https://arxiv.org/pdf/1911.08947.pdf>`_.

     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
     """
@@ -186,7 +179,6 @@ class _DBNet:
         """Compute the distance for each point of the map (xs, ys) to the (a, b) segment

         Args:
-        ----
             xs : map of x coordinates (height, width)
             ys : map of y coordinates (height, width)
             a: first point defining the [ab] segment
@@ -194,7 +186,6 @@ class _DBNet:
             eps: epsilon to avoid division by zero

         Returns:
-        -------
             The computed distance

         """
@@ -214,11 +205,10 @@ class _DBNet:
         polygon: np.ndarray,
         canvas: np.ndarray,
         mask: np.ndarray,
-    ) ->
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Draw a polygon treshold map on a canvas, as described in the DB paper

         Args:
-        ----
             polygon : array of coord., to draw the boundary of the polygon
             canvas : threshold map to fill with polygons
             mask : mask for training on threshold polygons
@@ -278,10 +268,10 @@ class _DBNet:

     def build_target(
         self,
-        target:
-        output_shape:
+        target: list[dict[str, np.ndarray]],
+        output_shape: tuple[int, int, int],
         channels_last: bool = True,
-    ) ->
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
             raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
         if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):