python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +8 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +7 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +4 -5
- doctr/datasets/ic13.py +4 -5
- doctr/datasets/iiit5k.py +6 -5
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +6 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +6 -5
- doctr/datasets/svt.py +4 -5
- doctr/datasets/synthtext.py +4 -5
- doctr/datasets/utils.py +34 -29
- doctr/datasets/vocabs.py +17 -7
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +2 -7
- doctr/io/elements.py +59 -79
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +30 -48
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +8 -11
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +5 -17
- doctr/models/classification/mobilenet/tensorflow.py +8 -21
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +6 -8
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +20 -31
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +8 -15
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +9 -12
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +6 -12
- doctr/models/classification/zoo.py +19 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +14 -26
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +14 -23
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +3 -7
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +18 -19
- doctr/models/kie_predictor/tensorflow.py +13 -14
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +12 -13
- doctr/models/predictor/tensorflow.py +8 -9
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +11 -23
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +12 -22
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +16 -22
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +12 -21
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +12 -20
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +14 -17
- doctr/models/utils/tensorflow.py +17 -16
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +16 -47
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
|
-
from doctr.file_utils import is_tf_available
|
|
8
|
+
from doctr.file_utils import is_tf_available, is_torch_available
|
|
9
9
|
|
|
10
10
|
from .. import classification
|
|
11
11
|
from ..preprocessor import PreProcessor
|
|
@@ -13,7 +13,7 @@ from .predictor import OrientationPredictor
|
|
|
13
13
|
|
|
14
14
|
__all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
|
|
15
15
|
|
|
16
|
-
ARCHS:
|
|
16
|
+
ARCHS: list[str] = [
|
|
17
17
|
"magc_resnet31",
|
|
18
18
|
"mobilenet_v3_small",
|
|
19
19
|
"mobilenet_v3_small_r",
|
|
@@ -31,7 +31,7 @@ ARCHS: List[str] = [
|
|
|
31
31
|
"vit_s",
|
|
32
32
|
"vit_b",
|
|
33
33
|
]
|
|
34
|
-
ORIENTATION_ARCHS:
|
|
34
|
+
ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
def _orientation_predictor(
|
|
@@ -48,7 +48,14 @@ def _orientation_predictor(
|
|
|
48
48
|
# Load directly classifier from backbone
|
|
49
49
|
_model = classification.__dict__[arch](pretrained=pretrained)
|
|
50
50
|
else:
|
|
51
|
-
|
|
51
|
+
allowed_archs = [classification.MobileNetV3]
|
|
52
|
+
if is_torch_available():
|
|
53
|
+
# Adding the type for torch compiled models to the allowed architectures
|
|
54
|
+
from doctr.models.utils import _CompiledModule
|
|
55
|
+
|
|
56
|
+
allowed_archs.append(_CompiledModule)
|
|
57
|
+
|
|
58
|
+
if not isinstance(arch, tuple(allowed_archs)):
|
|
52
59
|
raise ValueError(f"unknown architecture: {type(arch)}")
|
|
53
60
|
_model = arch
|
|
54
61
|
|
|
@@ -63,7 +70,7 @@ def _orientation_predictor(
|
|
|
63
70
|
|
|
64
71
|
|
|
65
72
|
def crop_orientation_predictor(
|
|
66
|
-
arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
|
|
73
|
+
arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, batch_size: int = 128, **kwargs: Any
|
|
67
74
|
) -> OrientationPredictor:
|
|
68
75
|
"""Crop orientation classification architecture.
|
|
69
76
|
|
|
@@ -74,20 +81,19 @@ def crop_orientation_predictor(
|
|
|
74
81
|
>>> out = model([input_crop])
|
|
75
82
|
|
|
76
83
|
Args:
|
|
77
|
-
----
|
|
78
84
|
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
|
|
79
85
|
pretrained: If True, returns a model pre-trained on our recognition crops dataset
|
|
86
|
+
batch_size: number of samples the model processes in parallel
|
|
80
87
|
**kwargs: keyword arguments to be passed to the OrientationPredictor
|
|
81
88
|
|
|
82
89
|
Returns:
|
|
83
|
-
-------
|
|
84
90
|
OrientationPredictor
|
|
85
91
|
"""
|
|
86
|
-
return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)
|
|
92
|
+
return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="crop", **kwargs)
|
|
87
93
|
|
|
88
94
|
|
|
89
95
|
def page_orientation_predictor(
|
|
90
|
-
arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
|
|
96
|
+
arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, batch_size: int = 4, **kwargs: Any
|
|
91
97
|
) -> OrientationPredictor:
|
|
92
98
|
"""Page orientation classification architecture.
|
|
93
99
|
|
|
@@ -98,13 +104,12 @@ def page_orientation_predictor(
|
|
|
98
104
|
>>> out = model([input_page])
|
|
99
105
|
|
|
100
106
|
Args:
|
|
101
|
-
----
|
|
102
107
|
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
|
|
103
108
|
pretrained: If True, returns a model pre-trained on our recognition crops dataset
|
|
109
|
+
batch_size: number of samples the model processes in parallel
|
|
104
110
|
**kwargs: keyword arguments to be passed to the OrientationPredictor
|
|
105
111
|
|
|
106
112
|
Returns:
|
|
107
|
-
-------
|
|
108
113
|
OrientationPredictor
|
|
109
114
|
"""
|
|
110
|
-
return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
|
|
115
|
+
return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="page", **kwargs)
|
doctr/models/core.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
from doctr.utils.repr import NestedObject
|
|
10
10
|
|
|
@@ -14,6 +14,6 @@ __all__ = ["BaseModel"]
|
|
|
14
14
|
class BaseModel(NestedObject):
|
|
15
15
|
"""Implements abstract DetectionModel class"""
|
|
16
16
|
|
|
17
|
-
def __init__(self, cfg:
|
|
17
|
+
def __init__(self, cfg: dict[str, Any] | None = None) -> None:
|
|
18
18
|
super().__init__()
|
|
19
19
|
self.cfg = cfg
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
from doctr.file_utils import is_tf_available
|
|
1
|
+
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
from .base import *
|
|
3
3
|
|
|
4
|
-
if
|
|
5
|
-
from .tensorflow import *
|
|
6
|
-
else:
|
|
4
|
+
if is_torch_available():
|
|
7
5
|
from .pytorch import *
|
|
6
|
+
elif is_tf_available():
|
|
7
|
+
from .tensorflow import *
|
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Dict, List
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
|
|
@@ -11,16 +10,15 @@ __all__ = ["_remove_padding"]
|
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
def _remove_padding(
|
|
14
|
-
pages:
|
|
15
|
-
loc_preds:
|
|
13
|
+
pages: list[np.ndarray],
|
|
14
|
+
loc_preds: list[dict[str, np.ndarray]],
|
|
16
15
|
preserve_aspect_ratio: bool,
|
|
17
16
|
symmetric_pad: bool,
|
|
18
17
|
assume_straight_pages: bool,
|
|
19
|
-
) ->
|
|
18
|
+
) -> list[dict[str, np.ndarray]]:
|
|
20
19
|
"""Remove padding from the localization predictions
|
|
21
20
|
|
|
22
21
|
Args:
|
|
23
|
-
----
|
|
24
22
|
pages: list of pages
|
|
25
23
|
loc_preds: list of localization predictions
|
|
26
24
|
preserve_aspect_ratio: whether the aspect ratio was preserved during padding
|
|
@@ -28,7 +26,6 @@ def _remove_padding(
|
|
|
28
26
|
assume_straight_pages: whether the pages are assumed to be straight
|
|
29
27
|
|
|
30
28
|
Returns:
|
|
31
|
-
-------
|
|
32
29
|
list of unpaded localization predictions
|
|
33
30
|
"""
|
|
34
31
|
if preserve_aspect_ratio:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -13,12 +13,10 @@ def erode(x: Tensor, kernel_size: int) -> Tensor:
|
|
|
13
13
|
"""Performs erosion on a given tensor
|
|
14
14
|
|
|
15
15
|
Args:
|
|
16
|
-
----
|
|
17
16
|
x: boolean tensor of shape (N, C, H, W)
|
|
18
17
|
kernel_size: the size of the kernel to use for erosion
|
|
19
18
|
|
|
20
19
|
Returns:
|
|
21
|
-
-------
|
|
22
20
|
the eroded tensor
|
|
23
21
|
"""
|
|
24
22
|
_pad = (kernel_size - 1) // 2
|
|
@@ -30,12 +28,10 @@ def dilate(x: Tensor, kernel_size: int) -> Tensor:
|
|
|
30
28
|
"""Performs dilation on a given tensor
|
|
31
29
|
|
|
32
30
|
Args:
|
|
33
|
-
----
|
|
34
31
|
x: boolean tensor of shape (N, C, H, W)
|
|
35
32
|
kernel_size: the size of the kernel to use for dilation
|
|
36
33
|
|
|
37
34
|
Returns:
|
|
38
|
-
-------
|
|
39
35
|
the dilated tensor
|
|
40
36
|
"""
|
|
41
37
|
_pad = (kernel_size - 1) // 2
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -12,12 +12,10 @@ def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
|
|
|
12
12
|
"""Performs erosion on a given tensor
|
|
13
13
|
|
|
14
14
|
Args:
|
|
15
|
-
----
|
|
16
15
|
x: boolean tensor of shape (N, H, W, C)
|
|
17
16
|
kernel_size: the size of the kernel to use for erosion
|
|
18
17
|
|
|
19
18
|
Returns:
|
|
20
|
-
-------
|
|
21
19
|
the eroded tensor
|
|
22
20
|
"""
|
|
23
21
|
return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
|
|
@@ -27,12 +25,10 @@ def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
|
|
|
27
25
|
"""Performs dilation on a given tensor
|
|
28
26
|
|
|
29
27
|
Args:
|
|
30
|
-
----
|
|
31
28
|
x: boolean tensor of shape (N, H, W, C)
|
|
32
29
|
kernel_size: the size of the kernel to use for dilation
|
|
33
30
|
|
|
34
31
|
Returns:
|
|
35
|
-
-------
|
|
36
32
|
the dilated tensor
|
|
37
33
|
"""
|
|
38
34
|
return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
|
doctr/models/detection/core.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List
|
|
7
6
|
|
|
8
7
|
import cv2
|
|
9
8
|
import numpy as np
|
|
@@ -17,7 +16,6 @@ class DetectionPostProcessor(NestedObject):
|
|
|
17
16
|
"""Abstract class to postprocess the raw output of the model
|
|
18
17
|
|
|
19
18
|
Args:
|
|
20
|
-
----
|
|
21
19
|
box_thresh (float): minimal objectness score to consider a box
|
|
22
20
|
bin_thresh (float): threshold to apply to segmentation raw heatmap
|
|
23
21
|
assume straight_pages (bool): if True, fit straight boxes only
|
|
@@ -37,13 +35,11 @@ class DetectionPostProcessor(NestedObject):
|
|
|
37
35
|
"""Compute the confidence score for a polygon : mean of the p values on the polygon
|
|
38
36
|
|
|
39
37
|
Args:
|
|
40
|
-
----
|
|
41
38
|
pred (np.ndarray): p map returned by the model
|
|
42
39
|
points: coordinates of the polygon
|
|
43
40
|
assume_straight_pages: if True, fit straight boxes only
|
|
44
41
|
|
|
45
42
|
Returns:
|
|
46
|
-
-------
|
|
47
43
|
polygon objectness
|
|
48
44
|
"""
|
|
49
45
|
h, w = pred.shape[:2]
|
|
@@ -71,15 +67,13 @@ class DetectionPostProcessor(NestedObject):
|
|
|
71
67
|
def __call__(
|
|
72
68
|
self,
|
|
73
69
|
proba_map,
|
|
74
|
-
) ->
|
|
70
|
+
) -> list[list[np.ndarray]]:
|
|
75
71
|
"""Performs postprocessing for a list of model outputs
|
|
76
72
|
|
|
77
73
|
Args:
|
|
78
|
-
----
|
|
79
74
|
proba_map: probability map of shape (N, H, W, C)
|
|
80
75
|
|
|
81
76
|
Returns:
|
|
82
|
-
-------
|
|
83
77
|
list of N class predictions (for each input sample), where each class predictions is a list of C tensors
|
|
84
78
|
of shape (*, 5) or (*, 6)
|
|
85
79
|
"""
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
elif
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,11 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
7
|
|
|
8
|
-
from typing import Dict, List, Tuple, Union
|
|
9
8
|
|
|
10
9
|
import cv2
|
|
11
10
|
import numpy as np
|
|
@@ -22,7 +21,6 @@ class DBPostProcessor(DetectionPostProcessor):
|
|
|
22
21
|
<https://github.com/xuannianz/DifferentiableBinarization>`_.
|
|
23
22
|
|
|
24
23
|
Args:
|
|
25
|
-
----
|
|
26
24
|
unclip ratio: ratio used to unshrink polygons
|
|
27
25
|
min_size_box: minimal length (pix) to keep a box
|
|
28
26
|
max_candidates: maximum boxes to consider in a single page
|
|
@@ -47,11 +45,9 @@ class DBPostProcessor(DetectionPostProcessor):
|
|
|
47
45
|
"""Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
|
|
48
46
|
|
|
49
47
|
Args:
|
|
50
|
-
----
|
|
51
48
|
points: The first parameter.
|
|
52
49
|
|
|
53
50
|
Returns:
|
|
54
|
-
-------
|
|
55
51
|
a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
|
|
56
52
|
"""
|
|
57
53
|
if not self.assume_straight_pages:
|
|
@@ -96,25 +92,23 @@ class DBPostProcessor(DetectionPostProcessor):
|
|
|
96
92
|
"""Compute boxes from a bitmap/pred_map: find connected components then filter boxes
|
|
97
93
|
|
|
98
94
|
Args:
|
|
99
|
-
----
|
|
100
95
|
pred: Pred map from differentiable binarization output
|
|
101
96
|
bitmap: Bitmap map computed from pred (binarized)
|
|
102
97
|
angle_tol: Comparison tolerance of the angle with the median angle across the page
|
|
103
98
|
ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
|
|
104
99
|
|
|
105
100
|
Returns:
|
|
106
|
-
-------
|
|
107
101
|
np tensor boxes for the bitmap, each box is a 5-element list
|
|
108
102
|
containing x, y, w, h, score for the box
|
|
109
103
|
"""
|
|
110
104
|
height, width = bitmap.shape[:2]
|
|
111
105
|
min_size_box = 2
|
|
112
|
-
boxes:
|
|
106
|
+
boxes: list[np.ndarray | list[float]] = []
|
|
113
107
|
# get contours from connected components on the bitmap
|
|
114
108
|
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
115
109
|
for contour in contours:
|
|
116
110
|
# Check whether smallest enclosing bounding box is not too small
|
|
117
|
-
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
|
|
111
|
+
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
|
|
118
112
|
continue
|
|
119
113
|
# Compute objectness
|
|
120
114
|
if self.assume_straight_pages:
|
|
@@ -164,7 +158,6 @@ class _DBNet:
|
|
|
164
158
|
<https://arxiv.org/pdf/1911.08947.pdf>`_.
|
|
165
159
|
|
|
166
160
|
Args:
|
|
167
|
-
----
|
|
168
161
|
feature extractor: the backbone serving as feature extractor
|
|
169
162
|
fpn_channels: number of channels each extracted feature maps is mapped to
|
|
170
163
|
"""
|
|
@@ -186,7 +179,6 @@ class _DBNet:
|
|
|
186
179
|
"""Compute the distance for each point of the map (xs, ys) to the (a, b) segment
|
|
187
180
|
|
|
188
181
|
Args:
|
|
189
|
-
----
|
|
190
182
|
xs : map of x coordinates (height, width)
|
|
191
183
|
ys : map of y coordinates (height, width)
|
|
192
184
|
a: first point defining the [ab] segment
|
|
@@ -194,7 +186,6 @@ class _DBNet:
|
|
|
194
186
|
eps: epsilon to avoid division by zero
|
|
195
187
|
|
|
196
188
|
Returns:
|
|
197
|
-
-------
|
|
198
189
|
The computed distance
|
|
199
190
|
|
|
200
191
|
"""
|
|
@@ -214,11 +205,10 @@ class _DBNet:
|
|
|
214
205
|
polygon: np.ndarray,
|
|
215
206
|
canvas: np.ndarray,
|
|
216
207
|
mask: np.ndarray,
|
|
217
|
-
) ->
|
|
208
|
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
218
209
|
"""Draw a polygon treshold map on a canvas, as described in the DB paper
|
|
219
210
|
|
|
220
211
|
Args:
|
|
221
|
-
----
|
|
222
212
|
polygon : array of coord., to draw the boundary of the polygon
|
|
223
213
|
canvas : threshold map to fill with polygons
|
|
224
214
|
mask : mask for training on threshold polygons
|
|
@@ -278,10 +268,10 @@ class _DBNet:
|
|
|
278
268
|
|
|
279
269
|
def build_target(
|
|
280
270
|
self,
|
|
281
|
-
target:
|
|
282
|
-
output_shape:
|
|
271
|
+
target: list[dict[str, np.ndarray]],
|
|
272
|
+
output_shape: tuple[int, int, int],
|
|
283
273
|
channels_last: bool = True,
|
|
284
|
-
) ->
|
|
274
|
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
|
285
275
|
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
|
|
286
276
|
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
|
|
287
277
|
if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
|
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
import torch
|
|
@@ -22,7 +23,7 @@ from .base import DBPostProcessor, _DBNet
|
|
|
22
23
|
__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
|
|
23
24
|
|
|
24
25
|
|
|
25
|
-
default_cfgs:
|
|
26
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
26
27
|
"db_resnet50": {
|
|
27
28
|
"input_shape": (3, 1024, 1024),
|
|
28
29
|
"mean": (0.798, 0.785, 0.772),
|
|
@@ -47,7 +48,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
47
48
|
class FeaturePyramidNetwork(nn.Module):
|
|
48
49
|
def __init__(
|
|
49
50
|
self,
|
|
50
|
-
in_channels:
|
|
51
|
+
in_channels: list[int],
|
|
51
52
|
out_channels: int,
|
|
52
53
|
deform_conv: bool = False,
|
|
53
54
|
) -> None:
|
|
@@ -76,12 +77,12 @@ class FeaturePyramidNetwork(nn.Module):
|
|
|
76
77
|
for idx, chans in enumerate(in_channels)
|
|
77
78
|
])
|
|
78
79
|
|
|
79
|
-
def forward(self, x:
|
|
80
|
+
def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
|
|
80
81
|
if len(x) != len(self.out_branches):
|
|
81
82
|
raise AssertionError
|
|
82
83
|
# Conv1x1 to get the same number of channels
|
|
83
|
-
_x:
|
|
84
|
-
out:
|
|
84
|
+
_x: list[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
|
|
85
|
+
out: list[torch.Tensor] = [_x[-1]]
|
|
85
86
|
for t in _x[:-1][::-1]:
|
|
86
87
|
out.append(self.upsample(out[-1]) + t)
|
|
87
88
|
|
|
@@ -96,7 +97,6 @@ class DBNet(_DBNet, nn.Module):
|
|
|
96
97
|
<https://arxiv.org/pdf/1911.08947.pdf>`_.
|
|
97
98
|
|
|
98
99
|
Args:
|
|
99
|
-
----
|
|
100
100
|
feature extractor: the backbone serving as feature extractor
|
|
101
101
|
head_chans: the number of channels in the head
|
|
102
102
|
deform_conv: whether to use deformable convolution
|
|
@@ -117,8 +117,8 @@ class DBNet(_DBNet, nn.Module):
|
|
|
117
117
|
box_thresh: float = 0.1,
|
|
118
118
|
assume_straight_pages: bool = True,
|
|
119
119
|
exportable: bool = False,
|
|
120
|
-
cfg:
|
|
121
|
-
class_names:
|
|
120
|
+
cfg: dict[str, Any] | None = None,
|
|
121
|
+
class_names: list[str] = [CLASS_NAME],
|
|
122
122
|
) -> None:
|
|
123
123
|
super().__init__()
|
|
124
124
|
self.class_names = class_names
|
|
@@ -182,10 +182,10 @@ class DBNet(_DBNet, nn.Module):
|
|
|
182
182
|
def forward(
|
|
183
183
|
self,
|
|
184
184
|
x: torch.Tensor,
|
|
185
|
-
target:
|
|
185
|
+
target: list[np.ndarray] | None = None,
|
|
186
186
|
return_model_output: bool = False,
|
|
187
187
|
return_preds: bool = False,
|
|
188
|
-
) ->
|
|
188
|
+
) -> dict[str, torch.Tensor]:
|
|
189
189
|
# Extract feature maps at different stages
|
|
190
190
|
feats = self.feat_extractor(x)
|
|
191
191
|
feats = [feats[str(idx)] for idx in range(len(feats))]
|
|
@@ -193,7 +193,7 @@ class DBNet(_DBNet, nn.Module):
|
|
|
193
193
|
feat_concat = self.fpn(feats)
|
|
194
194
|
logits = self.prob_head(feat_concat)
|
|
195
195
|
|
|
196
|
-
out:
|
|
196
|
+
out: dict[str, Any] = {}
|
|
197
197
|
if self.exportable:
|
|
198
198
|
out["logits"] = logits
|
|
199
199
|
return out
|
|
@@ -205,11 +205,16 @@ class DBNet(_DBNet, nn.Module):
|
|
|
205
205
|
out["out_map"] = prob_map
|
|
206
206
|
|
|
207
207
|
if target is None or return_preds:
|
|
208
|
+
# Disable for torch.compile compatibility
|
|
209
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
210
|
+
def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
|
|
211
|
+
return [
|
|
212
|
+
dict(zip(self.class_names, preds))
|
|
213
|
+
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
|
|
214
|
+
]
|
|
215
|
+
|
|
208
216
|
# Post-process boxes (keep only text predictions)
|
|
209
|
-
out["preds"] =
|
|
210
|
-
dict(zip(self.class_names, preds))
|
|
211
|
-
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
|
|
212
|
-
]
|
|
217
|
+
out["preds"] = _postprocess(prob_map)
|
|
213
218
|
|
|
214
219
|
if target is not None:
|
|
215
220
|
thresh_map = self.thresh_head(feat_concat)
|
|
@@ -222,7 +227,7 @@ class DBNet(_DBNet, nn.Module):
|
|
|
222
227
|
self,
|
|
223
228
|
out_map: torch.Tensor,
|
|
224
229
|
thresh_map: torch.Tensor,
|
|
225
|
-
target:
|
|
230
|
+
target: list[np.ndarray],
|
|
226
231
|
gamma: float = 2.0,
|
|
227
232
|
alpha: float = 0.5,
|
|
228
233
|
eps: float = 1e-8,
|
|
@@ -231,7 +236,6 @@ class DBNet(_DBNet, nn.Module):
|
|
|
231
236
|
and a list of masks for each image. From there it computes the loss with the model output
|
|
232
237
|
|
|
233
238
|
Args:
|
|
234
|
-
----
|
|
235
239
|
out_map: output feature map of the model of shape (N, C, H, W)
|
|
236
240
|
thresh_map: threshold map of shape (N, C, H, W)
|
|
237
241
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
@@ -240,7 +244,6 @@ class DBNet(_DBNet, nn.Module):
|
|
|
240
244
|
eps: epsilon factor in dice loss
|
|
241
245
|
|
|
242
246
|
Returns:
|
|
243
|
-
-------
|
|
244
247
|
A loss tensor
|
|
245
248
|
"""
|
|
246
249
|
if gamma < 0:
|
|
@@ -273,7 +276,7 @@ class DBNet(_DBNet, nn.Module):
|
|
|
273
276
|
dice_map = torch.softmax(out_map, dim=1)
|
|
274
277
|
else:
|
|
275
278
|
# compute binary map instead
|
|
276
|
-
dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
|
|
279
|
+
dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) # type: ignore[assignment]
|
|
277
280
|
# Class reduced
|
|
278
281
|
inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
|
|
279
282
|
cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
|
|
@@ -290,10 +293,10 @@ def _dbnet(
|
|
|
290
293
|
arch: str,
|
|
291
294
|
pretrained: bool,
|
|
292
295
|
backbone_fn: Callable[[bool], nn.Module],
|
|
293
|
-
fpn_layers:
|
|
294
|
-
backbone_submodule:
|
|
296
|
+
fpn_layers: list[str],
|
|
297
|
+
backbone_submodule: str | None = None,
|
|
295
298
|
pretrained_backbone: bool = True,
|
|
296
|
-
ignore_keys:
|
|
299
|
+
ignore_keys: list[str] | None = None,
|
|
297
300
|
**kwargs: Any,
|
|
298
301
|
) -> DBNet:
|
|
299
302
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -341,12 +344,10 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
|
|
|
341
344
|
>>> out = model(input_tensor)
|
|
342
345
|
|
|
343
346
|
Args:
|
|
344
|
-
----
|
|
345
347
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
346
348
|
**kwargs: keyword arguments of the DBNet architecture
|
|
347
349
|
|
|
348
350
|
Returns:
|
|
349
|
-
-------
|
|
350
351
|
text detection architecture
|
|
351
352
|
"""
|
|
352
353
|
return _dbnet(
|
|
@@ -376,12 +377,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
|
|
|
376
377
|
>>> out = model(input_tensor)
|
|
377
378
|
|
|
378
379
|
Args:
|
|
379
|
-
----
|
|
380
380
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
381
381
|
**kwargs: keyword arguments of the DBNet architecture
|
|
382
382
|
|
|
383
383
|
Returns:
|
|
384
|
-
-------
|
|
385
384
|
text detection architecture
|
|
386
385
|
"""
|
|
387
386
|
return _dbnet(
|
|
@@ -411,12 +410,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
|
|
|
411
410
|
>>> out = model(input_tensor)
|
|
412
411
|
|
|
413
412
|
Args:
|
|
414
|
-
----
|
|
415
413
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
416
414
|
**kwargs: keyword arguments of the DBNet architecture
|
|
417
415
|
|
|
418
416
|
Returns:
|
|
419
|
-
-------
|
|
420
417
|
text detection architecture
|
|
421
418
|
"""
|
|
422
419
|
return _dbnet(
|