python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
```diff
--- doctr/models/kie_predictor/__init__.py (0.9.0)
+++ doctr/models/kie_predictor/__init__.py (0.11.0)
@@ -1,6 +1,6 @@
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-else:
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
```
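This hunk, together with the identical change to `doctr/models/predictor/__init__.py` at the end of this diff, reverses the backend dispatch order: 0.9.0 checked for TensorFlow first, while 0.11.0 checks for PyTorch first, so the PyTorch implementations are the ones re-exported when both frameworks are installed. A minimal sketch of that selection logic, assuming simplified availability checks (doctr's real `file_utils` does more than this, e.g. handling explicit backend overrides):

```python
# Simplified stand-ins for doctr.file_utils.is_torch_available / is_tf_available;
# the real helpers are more involved.
import importlib.util


def is_torch_available() -> bool:
    return importlib.util.find_spec("torch") is not None


def is_tf_available() -> bool:
    return importlib.util.find_spec("tensorflow") is not None


# 0.11.0 order: PyTorch wins whenever it is installed, TensorFlow is the fallback.
if is_torch_available():
    backend = "pytorch"
elif is_tf_available():
    backend = "tensorflow"
else:
    backend = None  # neither framework installed

print(f"selected backend: {backend}")
```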
```diff
--- doctr/models/kie_predictor/base.py (0.9.0)
+++ doctr/models/kie_predictor/base.py (0.11.0)
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
 from doctr.models.builder import KIEDocumentBuilder
 
@@ -17,7 +17,6 @@ class _KIEPredictor(_OCRPredictor):
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
-    ----
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
             without rotated textual elements.
         straighten_pages: if True, estimates the page general orientation based on the median line orientation.
@@ -30,8 +29,8 @@ class _KIEPredictor(_OCRPredictor):
         kwargs: keyword args of `DocumentBuilder`
     """
 
-    crop_orientation_predictor:
-    page_orientation_predictor:
+    crop_orientation_predictor: OrientationPredictor | None
+    page_orientation_predictor: OrientationPredictor | None
 
     def __init__(
         self,
@@ -46,4 +45,8 @@ class _KIEPredictor(_OCRPredictor):
             assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs
         )
 
+        # Remove the following arguments from kwargs after initialization of the parent class
+        kwargs.pop("disable_page_orientation", None)
+        kwargs.pop("disable_crop_orientation", None)
+
         self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)
```
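The new `kwargs.pop(...)` calls consume the predictor-level `disable_page_orientation` and `disable_crop_orientation` switches so that only builder-compatible keyword arguments reach `KIEDocumentBuilder(**kwargs)` on the last line. A small sketch of the pattern, using a hypothetical stand-in builder rather than doctr's class:

```python
# Hypothetical stand-in: like a real builder, it only accepts the keyword
# arguments it knows about and would raise TypeError on anything else.
class StubBuilder:
    def __init__(self, resolve_lines: bool = True) -> None:
        self.resolve_lines = resolve_lines


def build_predictor(**kwargs):
    # Predictor-level flags are popped (with defaults) before forwarding ...
    disable_page_orientation = kwargs.pop("disable_page_orientation", False)
    disable_crop_orientation = kwargs.pop("disable_crop_orientation", False)
    # ... so the remaining kwargs are safe to hand to the builder.
    builder = StubBuilder(**kwargs)
    return disable_page_orientation, disable_crop_orientation, builder


print(build_predictor(disable_page_orientation=True, resolve_lines=False))
```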
```diff
--- doctr/models/kie_predictor/pytorch.py (0.9.0)
+++ doctr/models/kie_predictor/pytorch.py (0.11.0)
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
 import numpy as np
 import torch
@@ -24,7 +24,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
-    ----
         det_predictor: detection module
         reco_predictor: recognition module
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -52,8 +51,8 @@ class KIEPredictor(nn.Module, _KIEPredictor):
         **kwargs: Any,
     ) -> None:
         nn.Module.__init__(self)
-        self.det_predictor = det_predictor.eval()
-        self.reco_predictor = reco_predictor.eval()
+        self.det_predictor = det_predictor.eval()
+        self.reco_predictor = reco_predictor.eval()
         _KIEPredictor.__init__(
             self,
             assume_straight_pages,
@@ -69,7 +68,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
     @torch.inference_mode()
     def forward(
         self,
-        pages:
+        pages: list[np.ndarray | torch.Tensor],
         **kwargs: Any,
     ) -> Document:
         # Dimension check
@@ -89,7 +88,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             for out_map in out_maps
         ]
         if self.detect_orientation:
-            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
+            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
             orientations = [
                 {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
             ]
@@ -98,11 +97,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             general_pages_orientations = None
             origin_pages_orientations = None
         if self.straighten_pages:
-            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+            # update page shapes after straightening
+            origin_page_shapes = [page.shape[:2] for page in pages]
+
         # Forward again to get predictions on straight pages
         loc_preds = self.det_predictor(pages, **kwargs)
 
-        dict_loc_preds:
+        dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore[assignment]
 
         # Detach objectness scores from loc_preds
         objectness_scores = {}
@@ -122,10 +124,11 @@ class KIEPredictor(nn.Module, _KIEPredictor):
         crops = {}
         for class_name in dict_loc_preds.keys():
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
-                pages,
+                pages,
                 dict_loc_preds[class_name],
                 channels_last=channels_last,
                 assume_straight_pages=self.assume_straight_pages,
+                assume_horizontal=self._page_orientation_disabled,
             )
         # Rectify crop orientation
         crop_orientations: Any = {}
@@ -146,18 +149,18 @@ class KIEPredictor(nn.Module, _KIEPredictor):
         if not crop_orientations:
             crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
 
-        boxes:
-        text_preds:
-        word_crop_orientations:
+        boxes: dict = {}
+        text_preds: dict = {}
+        word_crop_orientations: dict = {}
         for class_name in dict_loc_preds.keys():
             boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
                 dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
             )
 
-        boxes_per_page:
-        objectness_scores_per_page:
-        text_preds_per_page:
-        crop_orientations_per_page:
+        boxes_per_page: list[dict] = invert_data_structure(boxes)  # type: ignore[assignment]
+        objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores)  # type: ignore[assignment]
+        text_preds_per_page: list[dict] = invert_data_structure(text_preds)  # type: ignore[assignment]
+        crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations)  # type: ignore[assignment]
 
         if self.detect_language:
             languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -166,7 +169,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             languages_dict = None
 
         out = self.doc_builder(
-            pages,
+            pages,
             boxes_per_page,
             objectness_scores_per_page,
             text_preds_per_page,
@@ -178,7 +181,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
         return out
 
     @staticmethod
-    def get_text(text_pred:
+    def get_text(text_pred: dict) -> str:
         text = []
         for value in text_pred.values():
             text += [item[0] for item in value]
```
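The new annotations make the role of `invert_data_structure` explicit: per-class collections are regrouped per page (and back). A rough, hypothetical re-implementation of the dict-of-lists to list-of-dicts direction used for the `*_per_page` variables above, for illustration only:

```python
# Hypothetical sketch; doctr's own helper also handles the reverse direction
# (list of per-page dicts -> dict of per-class lists).
def invert_data_structure(data: dict[str, list]) -> list[dict]:
    n_pages = len(next(iter(data.values()))) if data else 0
    return [{key: values[page] for key, values in data.items()} for page in range(n_pages)]


# Per-class results for two pages ...
boxes = {"words": ["w0", "w1"], "dates": ["d0", "d1"]}
# ... regrouped as one dict per page:
print(invert_data_structure(boxes))
# [{'words': 'w0', 'dates': 'd0'}, {'words': 'w1', 'dates': 'd1'}]
```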
```diff
--- doctr/models/kie_predictor/tensorflow.py (0.9.0)
+++ doctr/models/kie_predictor/tensorflow.py (0.11.0)
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -24,7 +24,6 @@ class KIEPredictor(NestedObject, _KIEPredictor):
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
-    ----
         det_predictor: detection module
         reco_predictor: recognition module
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -69,7 +68,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
 
     def __call__(
         self,
-        pages:
+        pages: list[np.ndarray | tf.Tensor],
         **kwargs: Any,
     ) -> Document:
         # Dimension check
@@ -99,10 +98,13 @@ class KIEPredictor(NestedObject, _KIEPredictor):
             origin_pages_orientations = None
         if self.straighten_pages:
             pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+            # update page shapes after straightening
+            origin_page_shapes = [page.shape[:2] for page in pages]
+
         # Forward again to get predictions on straight pages
-        loc_preds = self.det_predictor(pages, **kwargs)
+        loc_preds = self.det_predictor(pages, **kwargs)
 
-        dict_loc_preds:
+        dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore
 
         # Detach objectness scores from loc_preds
         objectness_scores = {}
@@ -119,7 +121,11 @@ class KIEPredictor(NestedObject, _KIEPredictor):
         crops = {}
         for class_name in dict_loc_preds.keys():
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
-                pages,
+                pages,
+                dict_loc_preds[class_name],
+                channels_last=True,
+                assume_straight_pages=self.assume_straight_pages,
+                assume_horizontal=self._page_orientation_disabled,
             )
 
         # Rectify crop orientation
@@ -141,18 +147,18 @@ class KIEPredictor(NestedObject, _KIEPredictor):
         if not crop_orientations:
             crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
 
-        boxes:
-        text_preds:
-        word_crop_orientations:
+        boxes: dict = {}
+        text_preds: dict = {}
+        word_crop_orientations: dict = {}
         for class_name in dict_loc_preds.keys():
             boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
                 dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
             )
 
-        boxes_per_page:
-        objectness_scores_per_page:
-        text_preds_per_page:
-        crop_orientations_per_page:
+        boxes_per_page: list[dict] = invert_data_structure(boxes)  # type: ignore[assignment]
+        objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores)  # type: ignore[assignment]
+        text_preds_per_page: list[dict] = invert_data_structure(text_preds)  # type: ignore[assignment]
+        crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations)  # type: ignore[assignment]
 
         if self.detect_language:
             languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -173,7 +179,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
         return out
 
     @staticmethod
-    def get_text(text_pred:
+    def get_text(text_pred: dict) -> str:
         text = []
         for value in text_pred.values():
             text += [item[0] for item in value]
```
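Both backends now refresh `origin_page_shapes` after `_straighten_pages`, since deskewing a page can change or swap its height and width, and shapes recorded before straightening would no longer describe the images handed to the document builder. A tiny numpy illustration of the effect (a sketch, not doctr's rotation code):

```python
import numpy as np

# A landscape 'page' before straightening ...
page = np.zeros((200, 400, 3), dtype=np.uint8)
print(page.shape[:2])  # (200, 400)

# ... after a 90-degree straightening rotation the height and width swap,
# so shapes captured earlier would be stale.
straightened = np.rot90(page)
print(straightened.shape[:2])  # (400, 200)
```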
```diff
--- doctr/models/modules/layers/__init__.py (0.9.0)
+++ doctr/models/modules/layers/__init__.py (0.11.0)
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
     from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
```
```diff
--- doctr/models/modules/layers/pytorch.py (0.9.0)
+++ doctr/models/modules/layers/pytorch.py (0.11.0)
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Tuple, Union
 
 import numpy as np
 import torch
@@ -19,7 +18,7 @@ class FASTConvLayer(nn.Module):
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size:
+        kernel_size: int | tuple[int, int],
         stride: int = 1,
         dilation: int = 1,
         groups: int = 1,
@@ -93,9 +92,7 @@
 
     # The following logic is used to reparametrize the layer
    # Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
-    def _identity_to_conv(
-        self, identity: Union[nn.BatchNorm2d, None]
-    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
+    def _identity_to_conv(self, identity: nn.BatchNorm2d | None) -> tuple[torch.Tensor, torch.Tensor] | tuple[int, int]:
         if identity is None or identity.running_var is None:
             return 0, 0
         if not hasattr(self, "id_tensor"):
@@ -106,18 +103,18 @@
             id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
             self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
         kernel = self.id_tensor
-        std = (identity.running_var + identity.eps).sqrt()
+        std = (identity.running_var + identity.eps).sqrt()  # type: ignore
         t = (identity.weight / std).reshape(-1, 1, 1, 1)
         return kernel * t, identity.bias - identity.running_mean * identity.weight / std
 
-    def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) ->
+    def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
         kernel = conv.weight
         kernel = self._pad_to_mxn_tensor(kernel)
         std = (bn.running_var + bn.eps).sqrt()  # type: ignore
         t = (bn.weight / std).reshape(-1, 1, 1, 1)
         return kernel * t, bn.bias - bn.running_mean * bn.weight / std
 
-    def _get_equivalent_kernel_bias(self) ->
+    def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
         kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
         if self.ver_conv is not None:
             kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)  # type: ignore[arg-type]
```
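`_fuse_bn_tensor` and `_get_equivalent_kernel_bias` implement the usual conv + BatchNorm folding used to reparametrize `FASTConvLayer` for inference: the BN scale is absorbed into the kernel and the BN shift into the bias. A standalone PyTorch check of that identity (my own sketch, not doctr's code):

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)  # give BN non-trivial running statistics
bn.running_var.uniform_(0.5, 1.5)

# Fold BN into the convolution: w' = w * gamma / std, b' = beta - mean * gamma / std
std = (bn.running_var + bn.eps).sqrt()
fused_weight = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
fused_bias = bn.bias - bn.running_mean * bn.weight / std

fused = nn.Conv2d(3, 8, kernel_size=3, padding=1)
with torch.no_grad():
    fused.weight.copy_(fused_weight)
    fused.bias.copy_(fused_bias)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    # The folded convolution reproduces bn(conv(x)) up to float precision.
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True
```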
```diff
--- doctr/models/modules/layers/tensorflow.py (0.9.0)
+++ doctr/models/modules/layers/tensorflow.py (0.11.0)
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -21,7 +21,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size:
+        kernel_size: int | tuple[int, int],
         stride: int = 1,
         dilation: int = 1,
         groups: int = 1,
@@ -103,9 +103,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
 
     # The following logic is used to reparametrize the layer
     # Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
-    def _identity_to_conv(
-        self, identity: layers.BatchNormalization
-    ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
+    def _identity_to_conv(self, identity: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor] | tuple[int, int]:
         if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
             return 0, 0
         if not hasattr(self, "id_tensor"):
@@ -120,7 +118,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
         t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
         return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
 
-    def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) ->
+    def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor]:
         kernel = conv.kernel
         kernel = self._pad_to_mxn_tensor(kernel)
         std = tf.sqrt(bn.moving_variance + bn.epsilon)
```
```diff
--- doctr/models/modules/transformer/__init__.py (0.9.0)
+++ doctr/models/modules/transformer/__init__.py (0.11.0)
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
     from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
```
```diff
--- doctr/models/modules/transformer/pytorch.py (0.9.0)
+++ doctr/models/modules/transformer/pytorch.py (0.11.0)
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,8 @@
 # This module 'transformer.py' is inspired by https://github.com/wenwenyu/MASTER-pytorch and Decoder is borrowed
 
 import math
-from
+from collections.abc import Callable
+from typing import Any
 
 import torch
 from torch import nn
@@ -33,26 +34,24 @@ class PositionalEncoding(nn.Module):
         """Forward pass
 
         Args:
-        ----
             x: embeddings (batch, max_len, d_model)
 
-        Returns
-        -------
+        Returns:
             positional embeddings (batch, max_len, d_model)
         """
-        x = x + self.pe[:, : x.size(1)]
+        x = x + self.pe[:, : x.size(1)]  # type: ignore[index]
         return self.dropout(x)
 
 
 def scaled_dot_product_attention(
-    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask:
-) ->
+    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor | None = None
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Scaled Dot-Product Attention"""
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
     if mask is not None:
         # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
-        scores = scores.masked_fill(mask == 0, float("-inf"))
-    p_attn = torch.softmax(scores, dim=-1)
+        scores = scores.masked_fill(mask == 0, float("-inf"))  # type: ignore[attr-defined]
+    p_attn = torch.softmax(scores, dim=-1)  # type: ignore[call-overload]
     return torch.matmul(p_attn, value), p_attn
 
 
@@ -130,7 +129,7 @@ class EncoderBlock(nn.Module):
             PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
         ])
 
-    def forward(self, x: torch.Tensor, mask:
+    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
         output = x
 
         for i in range(self.num_layers):
@@ -183,8 +182,8 @@ class Decoder(nn.Module):
         self,
         tgt: torch.Tensor,
         memory: torch.Tensor,
-        source_mask:
-        target_mask:
+        source_mask: torch.Tensor | None = None,
+        target_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
         tgt = self.embed(tgt) * math.sqrt(self.d_model)
         pos_enc_tgt = self.positional_encoding(tgt)
```
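The `scaled_dot_product_attention` helper touched above is the textbook formulation: scores are `QK^T / sqrt(d)`, and masked positions are set to `-inf` before the softmax so they receive zero attention weight. A compact PyTorch sketch mirroring the signature shown in the hunk:

```python
import math

import torch


def scaled_dot_product_attention(query, key, value, mask=None):
    # scores has shape (batch, heads, query_len, key_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    if mask is not None:
        # keys where mask == 0 get -inf scores and thus ~0 attention weight
        scores = scores.masked_fill(mask == 0, float("-inf"))
    p_attn = torch.softmax(scores, dim=-1)
    return torch.matmul(p_attn, value), p_attn


q = k = v = torch.randn(1, 2, 4, 8)               # (batch, heads, seq, dim)
mask = torch.ones(1, 1, 4, 4, dtype=torch.int32)
mask[..., -1] = 0                                  # hide the last key position
out, attn = scaled_dot_product_attention(q, k, v, mask)
print(out.shape, attn[..., -1].abs().max())        # weights on the masked key are 0
```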
```diff
--- doctr/models/modules/transformer/tensorflow.py (0.9.0)
+++ doctr/models/modules/transformer/tensorflow.py (0.11.0)
@@ -1,10 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
-from
+from collections.abc import Callable
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import layers
@@ -13,8 +14,6 @@ from doctr.utils.repr import NestedObject
 
 __all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "PositionwiseFeedForward", "MultiHeadAttention"]
 
-tf.config.run_functions_eagerly(True)
-
 
 class PositionalEncoding(layers.Layer, NestedObject):
     """Compute positional encoding"""
@@ -45,12 +44,10 @@ class PositionalEncoding(layers.Layer, NestedObject):
         """Forward pass
 
         Args:
-        ----
             x: embeddings (batch, max_len, d_model)
             **kwargs: additional arguments
 
-        Returns
-        -------
+        Returns:
             positional embeddings (batch, max_len, d_model)
         """
         if x.dtype == tf.float16:  # amp fix: cast to half
@@ -62,8 +59,8 @@
 
 @tf.function
 def scaled_dot_product_attention(
-    query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask:
-) ->
+    query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: tf.Tensor | None = None
+) -> tuple[tf.Tensor, tf.Tensor]:
     """Scaled Dot-Product Attention"""
     scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) / math.sqrt(query.shape[-1])
     if mask is not None:
@@ -162,7 +159,7 @@ class EncoderBlock(layers.Layer, NestedObject):
             PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
         ]
 
-    def call(self, x: tf.Tensor, mask:
+    def call(self, x: tf.Tensor, mask: tf.Tensor | None = None, **kwargs: Any) -> tf.Tensor:
         output = x
 
         for i in range(self.num_layers):
@@ -212,8 +209,8 @@ class Decoder(layers.Layer, NestedObject):
         self,
         tgt: tf.Tensor,
         memory: tf.Tensor,
-        source_mask:
-        target_mask:
+        source_mask: tf.Tensor | None = None,
+        target_mask: tf.Tensor | None = None,
         **kwargs: Any,
     ) -> tf.Tensor:
         tgt = self.embed(tgt, **kwargs) * math.sqrt(self.d_model)
```
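The `PositionalEncoding` layers patched in both backends add a precomputed sinusoidal table to the input embeddings (`x = x + self.pe[:, : x.size(1)]` in the PyTorch hunk). A small numpy sketch of how such a table is commonly built; doctr's exact construction may differ in detail:

```python
import numpy as np


def sinusoidal_position_encoding(max_len: int, d_model: int) -> np.ndarray:
    # pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    # pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    position = np.arange(max_len)[:, None]
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe[None]  # leading batch axis, like self.pe


x = np.random.rand(2, 10, 32)              # (batch, max_len, d_model)
pe = sinusoidal_position_encoding(100, 32)
out = x + pe[:, : x.shape[1]]              # same slicing as in the hunk above
print(out.shape)                           # (2, 10, 32)
```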
```diff
--- doctr/models/modules/vision_transformer/__init__.py (0.9.0)
+++ doctr/models/modules/vision_transformer/__init__.py (0.11.0)
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
     from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
```
```diff
--- doctr/models/modules/vision_transformer/pytorch.py (0.9.0)
+++ doctr/models/modules/vision_transformer/pytorch.py (0.11.0)
@@ -1,10 +1,9 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
-from typing import Tuple
 
 import torch
 from torch import nn
@@ -15,12 +14,12 @@ __all__ = ["PatchEmbedding"]
 class PatchEmbedding(nn.Module):
     """Compute 2D patch embeddings with cls token and positional encoding"""
 
-    def __init__(self, input_shape:
+    def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
         super().__init__()
         channels, height, width = input_shape
         self.patch_size = patch_size
         self.interpolate = True if patch_size[0] == patch_size[1] else False
-        self.grid_size = tuple(
+        self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
         self.num_patches = self.grid_size[0] * self.grid_size[1]
 
         self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
```
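The rewritten `self.grid_size` line spells out the patch-grid computation: height and width are integer-divided by the corresponding patch dimension, and `num_patches` is the product of the two. A quick worked example with illustrative values (not taken from a doctr config):

```python
# e.g. a (channels, height, width) = (3, 32, 128) input split into (4, 8) patches
height, width = 32, 128
patch_size = (4, 8)

grid_size = tuple(s // p for s, p in zip((height, width), patch_size))
num_patches = grid_size[0] * grid_size[1]

print(grid_size, num_patches)  # (8, 16) 128
```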
```diff
--- doctr/models/modules/vision_transformer/tensorflow.py (0.9.0)
+++ doctr/models/modules/vision_transformer/tensorflow.py (0.11.0)
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
-from typing import Any
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import layers
@@ -17,12 +17,12 @@ __all__ = ["PatchEmbedding"]
 class PatchEmbedding(layers.Layer, NestedObject):
     """Compute 2D patch embeddings with cls token and positional encoding"""
 
-    def __init__(self, input_shape:
+    def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
         super().__init__()
         height, width, _ = input_shape
         self.patch_size = patch_size
         self.interpolate = True if patch_size[0] == patch_size[1] else False
-        self.grid_size = tuple(
+        self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
         self.num_patches = self.grid_size[0] * self.grid_size[1]
 
         self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token")
```
```diff
--- doctr/models/predictor/__init__.py (0.9.0)
+++ doctr/models/predictor/__init__.py (0.11.0)
@@ -1,6 +1,6 @@
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-else:
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
```