python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/classification/vit/tensorflow.py CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Sequential, layers
@@ -19,7 +19,7 @@ from ...utils import _build_model, load_pretrained_params
 __all__ = ["vit_s", "vit_b"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "vit_s": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -41,7 +41,6 @@ class ClassifierHead(layers.Layer, NestedObject):
     """Classifier head for Vision Transformer
 
     Args:
-    ----
         num_classes: number of output classes
     """
 
@@ -61,7 +60,6 @@ class VisionTransformer(Sequential):
     <https://arxiv.org/pdf/2010.11929.pdf>`_.
 
     Args:
-    ----
         d_model: dimension of the transformer layers
         num_layers: number of transformer layers
         num_heads: number of attention heads
@@ -79,12 +77,12 @@ class VisionTransformer(Sequential):
         num_layers: int,
         num_heads: int,
         ffd_ratio: int,
-        patch_size:
-        input_shape:
+        patch_size: tuple[int, int] = (4, 4),
+        input_shape: tuple[int, int, int] = (32, 32, 3),
         dropout: float = 0.0,
         num_classes: int = 1000,
         include_top: bool = True,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         _layers = [
             PatchEmbedding(input_shape, d_model, patch_size),
@@ -103,6 +101,15 @@ class VisionTransformer(Sequential):
         super().__init__(_layers)
         self.cfg = cfg
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
 
 def _vit(
     arch: str,
@@ -148,12 +155,10 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
@@ -179,12 +184,10 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
-    -------
         A feature extractor model
     """
     return _vit(
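The `from_pretrained` method added above is new public API on the TensorFlow `VisionTransformer`: a freshly built architecture can pull weights from a local checkpoint or a URL instead of relying solely on `pretrained=True`. A minimal usage sketch (the checkpoint path is a placeholder, not a file shipped with the package):

import tensorflow as tf

from doctr.models import vit_s

# Build the architecture without downloading the default weights ...
model = vit_s(pretrained=False)
# ... then load parameters from an arbitrary checkpoint or URL
model.from_pretrained("path/to/my_checkpoint")  # placeholder path

out = model(tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32))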
doctr/models/classification/zoo.py CHANGED

@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 
 from .. import classification
 from ..preprocessor import PreProcessor
@@ -13,7 +13,7 @@ from .predictor import OrientationPredictor
 
 __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
 
-ARCHS:
+ARCHS: list[str] = [
     "magc_resnet31",
     "mobilenet_v3_small",
     "mobilenet_v3_small_r",
@@ -31,7 +31,11 @@ ARCHS: List[str] = [
     "vit_s",
     "vit_b",
 ]
-
+
+if is_torch_available():
+    ARCHS.extend(["vip_tiny", "vip_base"])
+
+ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
 def _orientation_predictor(
@@ -48,7 +52,14 @@ def _orientation_predictor(
         # Load directly classifier from backbone
         _model = classification.__dict__[arch](pretrained=pretrained)
     else:
-
+        allowed_archs = [classification.MobileNetV3]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch
 
@@ -63,7 +74,7 @@ def _orientation_predictor(
 
 
 def crop_orientation_predictor(
-    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
+    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, batch_size: int = 128, **kwargs: Any
 ) -> OrientationPredictor:
     """Crop orientation classification architecture.
 
@@ -74,20 +85,19 @@ def crop_orientation_predictor(
     >>> out = model([input_crop])
 
     Args:
-    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        batch_size: number of samples the model processes in parallel
         **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
-    -------
         OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)
+    return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="crop", **kwargs)
 
 
 def page_orientation_predictor(
-    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
+    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, batch_size: int = 4, **kwargs: Any
 ) -> OrientationPredictor:
     """Page orientation classification architecture.
 
@@ -98,13 +108,12 @@ def page_orientation_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        batch_size: number of samples the model processes in parallel
        **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
-    -------
         OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
+    return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="page", **kwargs)
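Both factories now accept a `batch_size` argument (128 for crops, 4 for full pages by default) and forward every parameter to `_orientation_predictor` by keyword. A short usage sketch with synthetic numpy inputs:

import numpy as np

from doctr.models import crop_orientation_predictor

# Crops are small, so larger batches are cheap; full pages default to batch_size=4
predictor = crop_orientation_predictor(pretrained=True, batch_size=256)

crops = [np.random.randint(0, 255, (64, 48, 3), dtype=np.uint8) for _ in range(8)]
out = predictor(crops)

Torch users can also hand a `torch.compile`d MobileNetV3 directly as `arch`, since `_CompiledModule` is now part of the allowed architectures.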
doctr/models/core.py CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
-from typing import Any
+from typing import Any
 
 from doctr.utils.repr import NestedObject
 
@@ -14,6 +14,6 @@ __all__ = ["BaseModel"]
 class BaseModel(NestedObject):
     """Implements abstract DetectionModel class"""
 
-    def __init__(self, cfg:
+    def __init__(self, cfg: dict[str, Any] | None = None) -> None:
         super().__init__()
         self.cfg = cfg
doctr/models/detection/_utils/__init__.py CHANGED

@@ -1,7 +1,7 @@
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from .base import *
 
-if
-    from .tensorflow import *
-else:
+if is_torch_available():
     from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *
doctr/models/detection/_utils/base.py CHANGED

@@ -1,9 +1,8 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Dict, List
 
 import numpy as np
 
@@ -11,16 +10,15 @@ __all__ = ["_remove_padding"]
 
 
 def _remove_padding(
-    pages:
-    loc_preds:
+    pages: list[np.ndarray],
+    loc_preds: list[dict[str, np.ndarray]],
     preserve_aspect_ratio: bool,
     symmetric_pad: bool,
     assume_straight_pages: bool,
-) ->
+) -> list[dict[str, np.ndarray]]:
     """Remove padding from the localization predictions
 
     Args:
-    ----
         pages: list of pages
         loc_preds: list of localization predictions
         preserve_aspect_ratio: whether the aspect ratio was preserved during padding
@@ -28,7 +26,6 @@ def _remove_padding(
         assume_straight_pages: whether the pages are assumed to be straight
 
     Returns:
-    -------
         list of unpaded localization predictions
     """
     if preserve_aspect_ratio:
doctr/models/detection/_utils/pytorch.py CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,12 +13,10 @@ def erode(x: Tensor, kernel_size: int) -> Tensor:
     """Performs erosion on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, C, H, W)
         kernel_size: the size of the kernel to use for erosion
 
     Returns:
-    -------
         the eroded tensor
     """
     _pad = (kernel_size - 1) // 2
@@ -30,12 +28,10 @@ def dilate(x: Tensor, kernel_size: int) -> Tensor:
     """Performs dilation on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, C, H, W)
         kernel_size: the size of the kernel to use for dilation
 
     Returns:
-    -------
         the dilated tensor
     """
     _pad = (kernel_size - 1) // 2
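Both helpers implement binary morphology through max-pooling: dilation is a plain max-pool, and erosion max-pools the complement (the TensorFlow variant below makes this explicit with `1 - max_pool2d(1 - x)`). A small sanity check, assuming the PyTorch backend is installed and the functions are re-exported by `doctr.models.detection._utils`:

import torch

from doctr.models.detection._utils import dilate, erode

# A single 2x2 foreground blob in a (N, C, H, W) float mask
x = torch.zeros(1, 1, 5, 5)
x[0, 0, 2:4, 2:4] = 1.0

# A 3x3 erosion keeps only pixels whose entire neighborhood is foreground,
# so the 2x2 blob vanishes
assert erode(x, 3).sum() == 0

# A 3x3 dilation grows the blob by one pixel in every direction: 4x4 = 16
assert dilate(x, 3).sum() == 16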
doctr/models/detection/_utils/tensorflow.py CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,12 +12,10 @@ def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
     """Performs erosion on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, H, W, C)
         kernel_size: the size of the kernel to use for erosion
 
     Returns:
-    -------
         the eroded tensor
     """
     return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
@@ -27,12 +25,10 @@ def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
     """Performs dilation on a given tensor
 
     Args:
-    ----
         x: boolean tensor of shape (N, H, W, C)
         kernel_size: the size of the kernel to use for dilation
 
     Returns:
-    -------
         the dilated tensor
     """
     return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
doctr/models/detection/core.py CHANGED

@@ -1,9 +1,8 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List
 
 import cv2
 import numpy as np
@@ -17,7 +16,6 @@ class DetectionPostProcessor(NestedObject):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         box_thresh (float): minimal objectness score to consider a box
         bin_thresh (float): threshold to apply to segmentation raw heatmap
         assume straight_pages (bool): if True, fit straight boxes only
@@ -37,13 +35,11 @@ class DetectionPostProcessor(NestedObject):
         """Compute the confidence score for a polygon : mean of the p values on the polygon
 
         Args:
-        ----
             pred (np.ndarray): p map returned by the model
             points: coordinates of the polygon
             assume_straight_pages: if True, fit straight boxes only
 
         Returns:
-        -------
             polygon objectness
         """
         h, w = pred.shape[:2]
@@ -71,15 +67,13 @@
     def __call__(
         self,
         proba_map,
-    ) ->
+    ) -> list[list[np.ndarray]]:
         """Performs postprocessing for a list of model outputs
 
         Args:
-        ----
             proba_map: probability map of shape (N, H, W, C)
 
         Returns:
-        -------
             list of N class predictions (for each input sample), where each class predictions is a list of C tensors
             of shape (*, 5) or (*, 6)
         """
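Per the docstring above, `box_score` is the mean of the model's p-map over the polygon region. A simplified numpy sketch of the straight-box case (the coordinates are hypothetical, and the exact slicing in doctr may differ by a pixel):

import numpy as np

pred = np.zeros((100, 100), dtype=np.float32)
pred[20:40, 30:60] = 0.8  # a text region the model is 80% sure about

xmin, ymin, xmax, ymax = 30, 20, 60, 40  # absolute box coordinates
score = pred[ymin:ymax, xmin:xmax].mean()  # -> 0.8, the box objectness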
doctr/models/detection/differentiable_binarization/__init__.py CHANGED

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/detection/differentiable_binarization/base.py CHANGED

@@ -1,11 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -22,7 +21,6 @@ class DBPostProcessor(DetectionPostProcessor):
     <https://github.com/xuannianz/DifferentiableBinarization>`_.
 
     Args:
-    ----
         unclip ratio: ratio used to unshrink polygons
         min_size_box: minimal length (pix) to keep a box
         max_candidates: maximum boxes to consider in a single page
@@ -47,11 +45,9 @@ class DBPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
-        ----
             points: The first parameter.
 
         Returns:
-        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -62,9 +58,8 @@ class DBPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-
-
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
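The `else` branch for rotated pages replaces the polygon-object computation (note the removed `poly.length` line) with OpenCV primitives, which yield the same area and perimeter. The expansion distance itself is unchanged: D = A * r / L, from the DB paper. A quick numeric check on a 100 x 50 rectangle, with `unclip_ratio` assumed at 1.5 (the actual default lives on `DBPostProcessor`):

import cv2
import numpy as np

points = np.array([[0, 0], [100, 0], [100, 50], [0, 50]], dtype=np.float32)

area = cv2.contourArea(points)               # 5000.0
length = cv2.arcLength(points, closed=True)  # 300.0

unclip_ratio = 1.5  # assumed value for illustration
distance = area * unclip_ratio / length     # 25.0 px offset passed to pyclipper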
@@ -96,25 +91,23 @@
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
-        ----
             pred: Pred map from differentiable binarization output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
-        -------
             np tensor boxes for the bitmap, each box is a 5-element list
             containing x, y, w, h, score for the box
         """
         height, width = bitmap.shape[:2]
         min_size_box = 2
-        boxes:
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -164,7 +157,6 @@ class _DBNet:
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
     """
@@ -186,7 +178,6 @@
         """Compute the distance for each point of the map (xs, ys) to the (a, b) segment
 
         Args:
-        ----
             xs : map of x coordinates (height, width)
             ys : map of y coordinates (height, width)
             a: first point defining the [ab] segment
@@ -194,7 +185,6 @@
             eps: epsilon to avoid division by zero
 
         Returns:
-        -------
             The computed distance
 
         """
@@ -214,11 +204,10 @@
         polygon: np.ndarray,
         canvas: np.ndarray,
         mask: np.ndarray,
-    ) ->
-        """Draw a polygon
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """Draw a polygon threshold map on a canvas, as described in the DB paper
 
         Args:
-        ----
             polygon : array of coord., to draw the boundary of the polygon
             canvas : threshold map to fill with polygons
             mask : mask for training on threshold polygons
@@ -278,10 +267,10 @@
 
     def build_target(
         self,
-        target:
-        output_shape:
+        target: list[dict[str, np.ndarray]],
+        output_shape: tuple[int, int, int],
         channels_last: bool = True,
-    ) ->
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
             raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
         if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):