python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +8 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +7 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +4 -5
- doctr/datasets/ic13.py +4 -5
- doctr/datasets/iiit5k.py +6 -5
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +6 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +6 -5
- doctr/datasets/svt.py +4 -5
- doctr/datasets/synthtext.py +4 -5
- doctr/datasets/utils.py +34 -29
- doctr/datasets/vocabs.py +17 -7
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +2 -7
- doctr/io/elements.py +59 -79
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +30 -48
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +8 -11
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +5 -17
- doctr/models/classification/mobilenet/tensorflow.py +8 -21
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +6 -8
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +20 -31
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +8 -15
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +9 -12
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +6 -12
- doctr/models/classification/zoo.py +19 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +14 -26
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +14 -23
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +3 -7
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +18 -19
- doctr/models/kie_predictor/tensorflow.py +13 -14
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +12 -13
- doctr/models/predictor/tensorflow.py +8 -9
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +11 -23
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +12 -22
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +16 -22
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +12 -21
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +12 -20
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +14 -17
- doctr/models/utils/tensorflow.py +17 -16
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +16 -47
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
from doctr.models.builder import KIEDocumentBuilder
|
|
9
9
|
|
|
@@ -17,7 +17,6 @@ class _KIEPredictor(_OCRPredictor):
|
|
|
17
17
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
18
18
|
|
|
19
19
|
Args:
|
|
20
|
-
----
|
|
21
20
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
22
21
|
without rotated textual elements.
|
|
23
22
|
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
|
|
@@ -30,8 +29,8 @@ class _KIEPredictor(_OCRPredictor):
|
|
|
30
29
|
kwargs: keyword args of `DocumentBuilder`
|
|
31
30
|
"""
|
|
32
31
|
|
|
33
|
-
crop_orientation_predictor:
|
|
34
|
-
page_orientation_predictor:
|
|
32
|
+
crop_orientation_predictor: OrientationPredictor | None
|
|
33
|
+
page_orientation_predictor: OrientationPredictor | None
|
|
35
34
|
|
|
36
35
|
def __init__(
|
|
37
36
|
self,
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -24,7 +24,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
24
24
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
25
25
|
|
|
26
26
|
Args:
|
|
27
|
-
----
|
|
28
27
|
det_predictor: detection module
|
|
29
28
|
reco_predictor: recognition module
|
|
30
29
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
@@ -52,8 +51,8 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
52
51
|
**kwargs: Any,
|
|
53
52
|
) -> None:
|
|
54
53
|
nn.Module.__init__(self)
|
|
55
|
-
self.det_predictor = det_predictor.eval()
|
|
56
|
-
self.reco_predictor = reco_predictor.eval()
|
|
54
|
+
self.det_predictor = det_predictor.eval()
|
|
55
|
+
self.reco_predictor = reco_predictor.eval()
|
|
57
56
|
_KIEPredictor.__init__(
|
|
58
57
|
self,
|
|
59
58
|
assume_straight_pages,
|
|
@@ -69,7 +68,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
69
68
|
@torch.inference_mode()
|
|
70
69
|
def forward(
|
|
71
70
|
self,
|
|
72
|
-
pages:
|
|
71
|
+
pages: list[np.ndarray | torch.Tensor],
|
|
73
72
|
**kwargs: Any,
|
|
74
73
|
) -> Document:
|
|
75
74
|
# Dimension check
|
|
@@ -89,7 +88,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
89
88
|
for out_map in out_maps
|
|
90
89
|
]
|
|
91
90
|
if self.detect_orientation:
|
|
92
|
-
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
91
|
+
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
93
92
|
orientations = [
|
|
94
93
|
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
|
|
95
94
|
]
|
|
@@ -98,14 +97,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
98
97
|
general_pages_orientations = None
|
|
99
98
|
origin_pages_orientations = None
|
|
100
99
|
if self.straighten_pages:
|
|
101
|
-
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
100
|
+
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
102
101
|
# update page shapes after straightening
|
|
103
102
|
origin_page_shapes = [page.shape[:2] for page in pages]
|
|
104
103
|
|
|
105
104
|
# Forward again to get predictions on straight pages
|
|
106
105
|
loc_preds = self.det_predictor(pages, **kwargs)
|
|
107
106
|
|
|
108
|
-
dict_loc_preds:
|
|
107
|
+
dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]
|
|
109
108
|
|
|
110
109
|
# Detach objectness scores from loc_preds
|
|
111
110
|
objectness_scores = {}
|
|
@@ -125,7 +124,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
125
124
|
crops = {}
|
|
126
125
|
for class_name in dict_loc_preds.keys():
|
|
127
126
|
crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
|
|
128
|
-
pages,
|
|
127
|
+
pages,
|
|
129
128
|
dict_loc_preds[class_name],
|
|
130
129
|
channels_last=channels_last,
|
|
131
130
|
assume_straight_pages=self.assume_straight_pages,
|
|
@@ -150,18 +149,18 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
150
149
|
if not crop_orientations:
|
|
151
150
|
crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
|
|
152
151
|
|
|
153
|
-
boxes:
|
|
154
|
-
text_preds:
|
|
155
|
-
word_crop_orientations:
|
|
152
|
+
boxes: dict = {}
|
|
153
|
+
text_preds: dict = {}
|
|
154
|
+
word_crop_orientations: dict = {}
|
|
156
155
|
for class_name in dict_loc_preds.keys():
|
|
157
156
|
boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
|
|
158
157
|
dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
|
|
159
158
|
)
|
|
160
159
|
|
|
161
|
-
boxes_per_page:
|
|
162
|
-
objectness_scores_per_page:
|
|
163
|
-
text_preds_per_page:
|
|
164
|
-
crop_orientations_per_page:
|
|
160
|
+
boxes_per_page: list[dict] = invert_data_structure(boxes) # type: ignore[assignment]
|
|
161
|
+
objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
|
|
162
|
+
text_preds_per_page: list[dict] = invert_data_structure(text_preds) # type: ignore[assignment]
|
|
163
|
+
crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
|
|
165
164
|
|
|
166
165
|
if self.detect_language:
|
|
167
166
|
languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
|
|
@@ -170,7 +169,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
170
169
|
languages_dict = None
|
|
171
170
|
|
|
172
171
|
out = self.doc_builder(
|
|
173
|
-
pages,
|
|
172
|
+
pages,
|
|
174
173
|
boxes_per_page,
|
|
175
174
|
objectness_scores_per_page,
|
|
176
175
|
text_preds_per_page,
|
|
@@ -182,7 +181,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
182
181
|
return out
|
|
183
182
|
|
|
184
183
|
@staticmethod
|
|
185
|
-
def get_text(text_pred:
|
|
184
|
+
def get_text(text_pred: dict) -> str:
|
|
186
185
|
text = []
|
|
187
186
|
for value in text_pred.values():
|
|
188
187
|
text += [item[0] for item in value]
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
@@ -24,7 +24,6 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
24
24
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
25
25
|
|
|
26
26
|
Args:
|
|
27
|
-
----
|
|
28
27
|
det_predictor: detection module
|
|
29
28
|
reco_predictor: recognition module
|
|
30
29
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
@@ -69,7 +68,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
69
68
|
|
|
70
69
|
def __call__(
|
|
71
70
|
self,
|
|
72
|
-
pages:
|
|
71
|
+
pages: list[np.ndarray | tf.Tensor],
|
|
73
72
|
**kwargs: Any,
|
|
74
73
|
) -> Document:
|
|
75
74
|
# Dimension check
|
|
@@ -103,9 +102,9 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
103
102
|
origin_page_shapes = [page.shape[:2] for page in pages]
|
|
104
103
|
|
|
105
104
|
# Forward again to get predictions on straight pages
|
|
106
|
-
loc_preds = self.det_predictor(pages, **kwargs)
|
|
105
|
+
loc_preds = self.det_predictor(pages, **kwargs)
|
|
107
106
|
|
|
108
|
-
dict_loc_preds:
|
|
107
|
+
dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
|
|
109
108
|
|
|
110
109
|
# Detach objectness scores from loc_preds
|
|
111
110
|
objectness_scores = {}
|
|
@@ -148,18 +147,18 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
148
147
|
if not crop_orientations:
|
|
149
148
|
crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
|
|
150
149
|
|
|
151
|
-
boxes:
|
|
152
|
-
text_preds:
|
|
153
|
-
word_crop_orientations:
|
|
150
|
+
boxes: dict = {}
|
|
151
|
+
text_preds: dict = {}
|
|
152
|
+
word_crop_orientations: dict = {}
|
|
154
153
|
for class_name in dict_loc_preds.keys():
|
|
155
154
|
boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
|
|
156
155
|
dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
|
|
157
156
|
)
|
|
158
157
|
|
|
159
|
-
boxes_per_page:
|
|
160
|
-
objectness_scores_per_page:
|
|
161
|
-
text_preds_per_page:
|
|
162
|
-
crop_orientations_per_page:
|
|
158
|
+
boxes_per_page: list[dict] = invert_data_structure(boxes) # type: ignore[assignment]
|
|
159
|
+
objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
|
|
160
|
+
text_preds_per_page: list[dict] = invert_data_structure(text_preds) # type: ignore[assignment]
|
|
161
|
+
crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
|
|
163
162
|
|
|
164
163
|
if self.detect_language:
|
|
165
164
|
languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
|
|
@@ -180,7 +179,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
180
179
|
return out
|
|
181
180
|
|
|
182
181
|
@staticmethod
|
|
183
|
-
def get_text(text_pred:
|
|
182
|
+
def get_text(text_pred: dict) -> str:
|
|
184
183
|
text = []
|
|
185
184
|
for value in text_pred.values():
|
|
186
185
|
text += [item[0] for item in value]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Tuple, Union
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
import torch
|
|
@@ -19,7 +18,7 @@ class FASTConvLayer(nn.Module):
|
|
|
19
18
|
self,
|
|
20
19
|
in_channels: int,
|
|
21
20
|
out_channels: int,
|
|
22
|
-
kernel_size:
|
|
21
|
+
kernel_size: int | tuple[int, int],
|
|
23
22
|
stride: int = 1,
|
|
24
23
|
dilation: int = 1,
|
|
25
24
|
groups: int = 1,
|
|
@@ -93,9 +92,7 @@ class FASTConvLayer(nn.Module):
|
|
|
93
92
|
|
|
94
93
|
# The following logic is used to reparametrize the layer
|
|
95
94
|
# Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
|
|
96
|
-
def _identity_to_conv(
|
|
97
|
-
self, identity: Union[nn.BatchNorm2d, None]
|
|
98
|
-
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
|
95
|
+
def _identity_to_conv(self, identity: nn.BatchNorm2d | None) -> tuple[torch.Tensor, torch.Tensor] | tuple[int, int]:
|
|
99
96
|
if identity is None or identity.running_var is None:
|
|
100
97
|
return 0, 0
|
|
101
98
|
if not hasattr(self, "id_tensor"):
|
|
@@ -106,18 +103,18 @@ class FASTConvLayer(nn.Module):
|
|
|
106
103
|
id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
|
|
107
104
|
self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
|
|
108
105
|
kernel = self.id_tensor
|
|
109
|
-
std = (identity.running_var + identity.eps).sqrt()
|
|
106
|
+
std = (identity.running_var + identity.eps).sqrt() # type: ignore
|
|
110
107
|
t = (identity.weight / std).reshape(-1, 1, 1, 1)
|
|
111
108
|
return kernel * t, identity.bias - identity.running_mean * identity.weight / std
|
|
112
109
|
|
|
113
|
-
def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) ->
|
|
110
|
+
def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
|
|
114
111
|
kernel = conv.weight
|
|
115
112
|
kernel = self._pad_to_mxn_tensor(kernel)
|
|
116
113
|
std = (bn.running_var + bn.eps).sqrt() # type: ignore
|
|
117
114
|
t = (bn.weight / std).reshape(-1, 1, 1, 1)
|
|
118
115
|
return kernel * t, bn.bias - bn.running_mean * bn.weight / std
|
|
119
116
|
|
|
120
|
-
def _get_equivalent_kernel_bias(self) ->
|
|
117
|
+
def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
|
|
121
118
|
kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
|
|
122
119
|
if self.ver_conv is not None:
|
|
123
120
|
kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) # type: ignore[arg-type]
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
@@ -21,7 +21,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
21
21
|
self,
|
|
22
22
|
in_channels: int,
|
|
23
23
|
out_channels: int,
|
|
24
|
-
kernel_size:
|
|
24
|
+
kernel_size: int | tuple[int, int],
|
|
25
25
|
stride: int = 1,
|
|
26
26
|
dilation: int = 1,
|
|
27
27
|
groups: int = 1,
|
|
@@ -103,9 +103,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
103
103
|
|
|
104
104
|
# The following logic is used to reparametrize the layer
|
|
105
105
|
# Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
|
|
106
|
-
def _identity_to_conv(
|
|
107
|
-
self, identity: layers.BatchNormalization
|
|
108
|
-
) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
|
|
106
|
+
def _identity_to_conv(self, identity: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor] | tuple[int, int]:
|
|
109
107
|
if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
|
|
110
108
|
return 0, 0
|
|
111
109
|
if not hasattr(self, "id_tensor"):
|
|
@@ -120,7 +118,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
120
118
|
t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
|
|
121
119
|
return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
|
|
122
120
|
|
|
123
|
-
def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) ->
|
|
121
|
+
def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor]:
|
|
124
122
|
kernel = conv.kernel
|
|
125
123
|
kernel = self._pad_to_mxn_tensor(kernel)
|
|
126
124
|
std = tf.sqrt(bn.moving_variance + bn.epsilon)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -6,7 +6,8 @@
|
|
|
6
6
|
# This module 'transformer.py' is inspired by https://github.com/wenwenyu/MASTER-pytorch and Decoder is borrowed
|
|
7
7
|
|
|
8
8
|
import math
|
|
9
|
-
from
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from typing import Any
|
|
10
11
|
|
|
11
12
|
import torch
|
|
12
13
|
from torch import nn
|
|
@@ -33,26 +34,24 @@ class PositionalEncoding(nn.Module):
|
|
|
33
34
|
"""Forward pass
|
|
34
35
|
|
|
35
36
|
Args:
|
|
36
|
-
----
|
|
37
37
|
x: embeddings (batch, max_len, d_model)
|
|
38
38
|
|
|
39
|
-
Returns
|
|
40
|
-
-------
|
|
39
|
+
Returns:
|
|
41
40
|
positional embeddings (batch, max_len, d_model)
|
|
42
41
|
"""
|
|
43
|
-
x = x + self.pe[:, : x.size(1)]
|
|
42
|
+
x = x + self.pe[:, : x.size(1)] # type: ignore[index]
|
|
44
43
|
return self.dropout(x)
|
|
45
44
|
|
|
46
45
|
|
|
47
46
|
def scaled_dot_product_attention(
|
|
48
|
-
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask:
|
|
49
|
-
) ->
|
|
47
|
+
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor | None = None
|
|
48
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
50
49
|
"""Scaled Dot-Product Attention"""
|
|
51
50
|
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
|
|
52
51
|
if mask is not None:
|
|
53
52
|
# NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
|
|
54
|
-
scores = scores.masked_fill(mask == 0, float("-inf"))
|
|
55
|
-
p_attn = torch.softmax(scores, dim=-1)
|
|
53
|
+
scores = scores.masked_fill(mask == 0, float("-inf")) # type: ignore[attr-defined]
|
|
54
|
+
p_attn = torch.softmax(scores, dim=-1) # type: ignore[call-overload]
|
|
56
55
|
return torch.matmul(p_attn, value), p_attn
|
|
57
56
|
|
|
58
57
|
|
|
@@ -130,7 +129,7 @@ class EncoderBlock(nn.Module):
|
|
|
130
129
|
PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
|
|
131
130
|
])
|
|
132
131
|
|
|
133
|
-
def forward(self, x: torch.Tensor, mask:
|
|
132
|
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
|
|
134
133
|
output = x
|
|
135
134
|
|
|
136
135
|
for i in range(self.num_layers):
|
|
@@ -183,8 +182,8 @@ class Decoder(nn.Module):
|
|
|
183
182
|
self,
|
|
184
183
|
tgt: torch.Tensor,
|
|
185
184
|
memory: torch.Tensor,
|
|
186
|
-
source_mask:
|
|
187
|
-
target_mask:
|
|
185
|
+
source_mask: torch.Tensor | None = None,
|
|
186
|
+
target_mask: torch.Tensor | None = None,
|
|
188
187
|
) -> torch.Tensor:
|
|
189
188
|
tgt = self.embed(tgt) * math.sqrt(self.d_model)
|
|
190
189
|
pos_enc_tgt = self.positional_encoding(tgt)
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from typing import Any
|
|
8
9
|
|
|
9
10
|
import tensorflow as tf
|
|
10
11
|
from tensorflow.keras import layers
|
|
@@ -43,12 +44,10 @@ class PositionalEncoding(layers.Layer, NestedObject):
|
|
|
43
44
|
"""Forward pass
|
|
44
45
|
|
|
45
46
|
Args:
|
|
46
|
-
----
|
|
47
47
|
x: embeddings (batch, max_len, d_model)
|
|
48
48
|
**kwargs: additional arguments
|
|
49
49
|
|
|
50
|
-
Returns
|
|
51
|
-
-------
|
|
50
|
+
Returns:
|
|
52
51
|
positional embeddings (batch, max_len, d_model)
|
|
53
52
|
"""
|
|
54
53
|
if x.dtype == tf.float16: # amp fix: cast to half
|
|
@@ -60,8 +59,8 @@ class PositionalEncoding(layers.Layer, NestedObject):
|
|
|
60
59
|
|
|
61
60
|
@tf.function
|
|
62
61
|
def scaled_dot_product_attention(
|
|
63
|
-
query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask:
|
|
64
|
-
) ->
|
|
62
|
+
query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: tf.Tensor | None = None
|
|
63
|
+
) -> tuple[tf.Tensor, tf.Tensor]:
|
|
65
64
|
"""Scaled Dot-Product Attention"""
|
|
66
65
|
scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) / math.sqrt(query.shape[-1])
|
|
67
66
|
if mask is not None:
|
|
@@ -160,7 +159,7 @@ class EncoderBlock(layers.Layer, NestedObject):
|
|
|
160
159
|
PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
|
|
161
160
|
]
|
|
162
161
|
|
|
163
|
-
def call(self, x: tf.Tensor, mask:
|
|
162
|
+
def call(self, x: tf.Tensor, mask: tf.Tensor | None = None, **kwargs: Any) -> tf.Tensor:
|
|
164
163
|
output = x
|
|
165
164
|
|
|
166
165
|
for i in range(self.num_layers):
|
|
@@ -210,8 +209,8 @@ class Decoder(layers.Layer, NestedObject):
|
|
|
210
209
|
self,
|
|
211
210
|
tgt: tf.Tensor,
|
|
212
211
|
memory: tf.Tensor,
|
|
213
|
-
source_mask:
|
|
214
|
-
target_mask:
|
|
212
|
+
source_mask: tf.Tensor | None = None,
|
|
213
|
+
target_mask: tf.Tensor | None = None,
|
|
215
214
|
**kwargs: Any,
|
|
216
215
|
) -> tf.Tensor:
|
|
217
216
|
tgt = self.embed(tgt, **kwargs) * math.sqrt(self.d_model)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from typing import Tuple
|
|
8
7
|
|
|
9
8
|
import torch
|
|
10
9
|
from torch import nn
|
|
@@ -15,7 +14,7 @@ __all__ = ["PatchEmbedding"]
|
|
|
15
14
|
class PatchEmbedding(nn.Module):
|
|
16
15
|
"""Compute 2D patch embeddings with cls token and positional encoding"""
|
|
17
16
|
|
|
18
|
-
def __init__(self, input_shape:
|
|
17
|
+
def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
|
|
19
18
|
super().__init__()
|
|
20
19
|
channels, height, width = input_shape
|
|
21
20
|
self.patch_size = patch_size
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import tensorflow as tf
|
|
10
10
|
from tensorflow.keras import layers
|
|
@@ -17,7 +17,7 @@ __all__ = ["PatchEmbedding"]
|
|
|
17
17
|
class PatchEmbedding(layers.Layer, NestedObject):
|
|
18
18
|
"""Compute 2D patch embeddings with cls token and positional encoding"""
|
|
19
19
|
|
|
20
|
-
def __init__(self, input_shape:
|
|
20
|
+
def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
|
|
21
21
|
super().__init__()
|
|
22
22
|
height, width, _ = input_shape
|
|
23
23
|
self.patch_size = patch_size
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
from doctr.file_utils import is_tf_available
|
|
1
|
+
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|