python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -24,7 +24,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
24
24
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
25
25
|
|
|
26
26
|
Args:
|
|
27
|
-
----
|
|
28
27
|
det_predictor: detection module
|
|
29
28
|
reco_predictor: recognition module
|
|
30
29
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
@@ -52,8 +51,8 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
52
51
|
**kwargs: Any,
|
|
53
52
|
) -> None:
|
|
54
53
|
nn.Module.__init__(self)
|
|
55
|
-
self.det_predictor = det_predictor.eval()
|
|
56
|
-
self.reco_predictor = reco_predictor.eval()
|
|
54
|
+
self.det_predictor = det_predictor.eval()
|
|
55
|
+
self.reco_predictor = reco_predictor.eval()
|
|
57
56
|
_KIEPredictor.__init__(
|
|
58
57
|
self,
|
|
59
58
|
assume_straight_pages,
|
|
@@ -69,7 +68,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
69
68
|
@torch.inference_mode()
|
|
70
69
|
def forward(
|
|
71
70
|
self,
|
|
72
|
-
pages:
|
|
71
|
+
pages: list[np.ndarray | torch.Tensor],
|
|
73
72
|
**kwargs: Any,
|
|
74
73
|
) -> Document:
|
|
75
74
|
# Dimension check
|
|
@@ -89,7 +88,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
89
88
|
for out_map in out_maps
|
|
90
89
|
]
|
|
91
90
|
if self.detect_orientation:
|
|
92
|
-
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
91
|
+
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
93
92
|
orientations = [
|
|
94
93
|
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
|
|
95
94
|
]
|
|
@@ -98,14 +97,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
98
97
|
general_pages_orientations = None
|
|
99
98
|
origin_pages_orientations = None
|
|
100
99
|
if self.straighten_pages:
|
|
101
|
-
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
100
|
+
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
102
101
|
# update page shapes after straightening
|
|
103
102
|
origin_page_shapes = [page.shape[:2] for page in pages]
|
|
104
103
|
|
|
105
104
|
# Forward again to get predictions on straight pages
|
|
106
105
|
loc_preds = self.det_predictor(pages, **kwargs)
|
|
107
106
|
|
|
108
|
-
dict_loc_preds:
|
|
107
|
+
dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]
|
|
109
108
|
|
|
110
109
|
# Detach objectness scores from loc_preds
|
|
111
110
|
objectness_scores = {}
|
|
@@ -125,7 +124,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
125
124
|
crops = {}
|
|
126
125
|
for class_name in dict_loc_preds.keys():
|
|
127
126
|
crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
|
|
128
|
-
pages,
|
|
127
|
+
pages,
|
|
129
128
|
dict_loc_preds[class_name],
|
|
130
129
|
channels_last=channels_last,
|
|
131
130
|
assume_straight_pages=self.assume_straight_pages,
|
|
@@ -150,18 +149,18 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
150
149
|
if not crop_orientations:
|
|
151
150
|
crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
|
|
152
151
|
|
|
153
|
-
boxes:
|
|
154
|
-
text_preds:
|
|
155
|
-
word_crop_orientations:
|
|
152
|
+
boxes: dict = {}
|
|
153
|
+
text_preds: dict = {}
|
|
154
|
+
word_crop_orientations: dict = {}
|
|
156
155
|
for class_name in dict_loc_preds.keys():
|
|
157
156
|
boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
|
|
158
157
|
dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
|
|
159
158
|
)
|
|
160
159
|
|
|
161
|
-
boxes_per_page:
|
|
162
|
-
objectness_scores_per_page:
|
|
163
|
-
text_preds_per_page:
|
|
164
|
-
crop_orientations_per_page:
|
|
160
|
+
boxes_per_page: list[dict] = invert_data_structure(boxes) # type: ignore[assignment]
|
|
161
|
+
objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
|
|
162
|
+
text_preds_per_page: list[dict] = invert_data_structure(text_preds) # type: ignore[assignment]
|
|
163
|
+
crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
|
|
165
164
|
|
|
166
165
|
if self.detect_language:
|
|
167
166
|
languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
|
|
@@ -170,11 +169,11 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
170
169
|
languages_dict = None
|
|
171
170
|
|
|
172
171
|
out = self.doc_builder(
|
|
173
|
-
pages,
|
|
172
|
+
pages,
|
|
174
173
|
boxes_per_page,
|
|
175
174
|
objectness_scores_per_page,
|
|
176
175
|
text_preds_per_page,
|
|
177
|
-
origin_page_shapes,
|
|
176
|
+
origin_page_shapes,
|
|
178
177
|
crop_orientations_per_page,
|
|
179
178
|
orientations,
|
|
180
179
|
languages_dict,
|
|
@@ -182,7 +181,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
182
181
|
return out
|
|
183
182
|
|
|
184
183
|
@staticmethod
|
|
185
|
-
def get_text(text_pred:
|
|
184
|
+
def get_text(text_pred: dict) -> str:
|
|
186
185
|
text = []
|
|
187
186
|
for value in text_pred.values():
|
|
188
187
|
text += [item[0] for item in value]
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
@@ -24,7 +24,6 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
24
24
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
25
25
|
|
|
26
26
|
Args:
|
|
27
|
-
----
|
|
28
27
|
det_predictor: detection module
|
|
29
28
|
reco_predictor: recognition module
|
|
30
29
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
@@ -69,7 +68,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
69
68
|
|
|
70
69
|
def __call__(
|
|
71
70
|
self,
|
|
72
|
-
pages:
|
|
71
|
+
pages: list[np.ndarray | tf.Tensor],
|
|
73
72
|
**kwargs: Any,
|
|
74
73
|
) -> Document:
|
|
75
74
|
# Dimension check
|
|
@@ -103,9 +102,9 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
103
102
|
origin_page_shapes = [page.shape[:2] for page in pages]
|
|
104
103
|
|
|
105
104
|
# Forward again to get predictions on straight pages
|
|
106
|
-
loc_preds = self.det_predictor(pages, **kwargs)
|
|
105
|
+
loc_preds = self.det_predictor(pages, **kwargs)
|
|
107
106
|
|
|
108
|
-
dict_loc_preds:
|
|
107
|
+
dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
|
|
109
108
|
|
|
110
109
|
# Detach objectness scores from loc_preds
|
|
111
110
|
objectness_scores = {}
|
|
@@ -148,18 +147,18 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
148
147
|
if not crop_orientations:
|
|
149
148
|
crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
|
|
150
149
|
|
|
151
|
-
boxes:
|
|
152
|
-
text_preds:
|
|
153
|
-
word_crop_orientations:
|
|
150
|
+
boxes: dict = {}
|
|
151
|
+
text_preds: dict = {}
|
|
152
|
+
word_crop_orientations: dict = {}
|
|
154
153
|
for class_name in dict_loc_preds.keys():
|
|
155
154
|
boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
|
|
156
155
|
dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
|
|
157
156
|
)
|
|
158
157
|
|
|
159
|
-
boxes_per_page:
|
|
160
|
-
objectness_scores_per_page:
|
|
161
|
-
text_preds_per_page:
|
|
162
|
-
crop_orientations_per_page:
|
|
158
|
+
boxes_per_page: list[dict] = invert_data_structure(boxes) # type: ignore[assignment]
|
|
159
|
+
objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
|
|
160
|
+
text_preds_per_page: list[dict] = invert_data_structure(text_preds) # type: ignore[assignment]
|
|
161
|
+
crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
|
|
163
162
|
|
|
164
163
|
if self.detect_language:
|
|
165
164
|
languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
|
|
@@ -172,7 +171,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
172
171
|
boxes_per_page,
|
|
173
172
|
objectness_scores_per_page,
|
|
174
173
|
text_preds_per_page,
|
|
175
|
-
origin_page_shapes,
|
|
174
|
+
origin_page_shapes,
|
|
176
175
|
crop_orientations_per_page,
|
|
177
176
|
orientations,
|
|
178
177
|
languages_dict,
|
|
@@ -180,7 +179,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
180
179
|
return out
|
|
181
180
|
|
|
182
181
|
@staticmethod
|
|
183
|
-
def get_text(text_pred:
|
|
182
|
+
def get_text(text_pred: dict) -> str:
|
|
184
183
|
text = []
|
|
185
184
|
for value in text_pred.values():
|
|
186
185
|
text += [item[0] for item in value]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,15 +1,62 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Tuple, Union
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
import torch
|
|
10
9
|
import torch.nn as nn
|
|
11
10
|
|
|
12
|
-
__all__ = ["FASTConvLayer"]
|
|
11
|
+
__all__ = ["FASTConvLayer", "DropPath", "AdaptiveAvgPool2d"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class DropPath(nn.Module):
|
|
15
|
+
"""
|
|
16
|
+
DropPath (Drop Connect) layer. This is a stochastic version of the identity layer.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
# Borrowed from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
|
|
20
|
+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
|
21
|
+
super(DropPath, self).__init__()
|
|
22
|
+
self.drop_prob = drop_prob
|
|
23
|
+
self.scale_by_keep = scale_by_keep
|
|
24
|
+
|
|
25
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
26
|
+
if self.drop_prob == 0.0 or not self.training:
|
|
27
|
+
return x
|
|
28
|
+
keep_prob = 1 - self.drop_prob
|
|
29
|
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with different dimensions
|
|
30
|
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
|
31
|
+
if keep_prob > 0.0 and self.scale_by_keep:
|
|
32
|
+
random_tensor.div_(keep_prob)
|
|
33
|
+
return x * random_tensor
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class AdaptiveAvgPool2d(nn.Module):
|
|
37
|
+
"""
|
|
38
|
+
Custom AdaptiveAvgPool2d implementation which is ONNX and `torch.compile` compatible.
|
|
39
|
+
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
def __init__(self, output_size):
|
|
43
|
+
super().__init__()
|
|
44
|
+
self.output_size = output_size
|
|
45
|
+
|
|
46
|
+
def forward(self, x: torch.Tensor):
|
|
47
|
+
H_out, W_out = self.output_size
|
|
48
|
+
N, C, H, W = x.shape
|
|
49
|
+
|
|
50
|
+
out = torch.empty((N, C, H_out, W_out), device=x.device, dtype=x.dtype)
|
|
51
|
+
for oh in range(H_out):
|
|
52
|
+
start_h = (oh * H) // H_out
|
|
53
|
+
end_h = ((oh + 1) * H + H_out - 1) // H_out # ceil((oh+1)*H / H_out)
|
|
54
|
+
for ow in range(W_out):
|
|
55
|
+
start_w = (ow * W) // W_out
|
|
56
|
+
end_w = ((ow + 1) * W + W_out - 1) // W_out # ceil((ow+1)*W / W_out)
|
|
57
|
+
# average over the window
|
|
58
|
+
out[:, :, oh, ow] = x[:, :, start_h:end_h, start_w:end_w].mean(dim=(-2, -1))
|
|
59
|
+
return out
|
|
13
60
|
|
|
14
61
|
|
|
15
62
|
class FASTConvLayer(nn.Module):
|
|
@@ -19,7 +66,7 @@ class FASTConvLayer(nn.Module):
|
|
|
19
66
|
self,
|
|
20
67
|
in_channels: int,
|
|
21
68
|
out_channels: int,
|
|
22
|
-
kernel_size:
|
|
69
|
+
kernel_size: int | tuple[int, int],
|
|
23
70
|
stride: int = 1,
|
|
24
71
|
dilation: int = 1,
|
|
25
72
|
groups: int = 1,
|
|
@@ -93,9 +140,7 @@ class FASTConvLayer(nn.Module):
|
|
|
93
140
|
|
|
94
141
|
# The following logic is used to reparametrize the layer
|
|
95
142
|
# Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
|
|
96
|
-
def _identity_to_conv(
|
|
97
|
-
self, identity: Union[nn.BatchNorm2d, None]
|
|
98
|
-
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
|
143
|
+
def _identity_to_conv(self, identity: nn.BatchNorm2d | None) -> tuple[torch.Tensor, torch.Tensor] | tuple[int, int]:
|
|
99
144
|
if identity is None or identity.running_var is None:
|
|
100
145
|
return 0, 0
|
|
101
146
|
if not hasattr(self, "id_tensor"):
|
|
@@ -106,18 +151,18 @@ class FASTConvLayer(nn.Module):
|
|
|
106
151
|
id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
|
|
107
152
|
self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
|
|
108
153
|
kernel = self.id_tensor
|
|
109
|
-
std = (identity.running_var + identity.eps).sqrt()
|
|
154
|
+
std = (identity.running_var + identity.eps).sqrt() # type: ignore
|
|
110
155
|
t = (identity.weight / std).reshape(-1, 1, 1, 1)
|
|
111
156
|
return kernel * t, identity.bias - identity.running_mean * identity.weight / std
|
|
112
157
|
|
|
113
|
-
def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) ->
|
|
158
|
+
def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
|
|
114
159
|
kernel = conv.weight
|
|
115
160
|
kernel = self._pad_to_mxn_tensor(kernel)
|
|
116
161
|
std = (bn.running_var + bn.eps).sqrt() # type: ignore
|
|
117
162
|
t = (bn.weight / std).reshape(-1, 1, 1, 1)
|
|
118
163
|
return kernel * t, bn.bias - bn.running_mean * bn.weight / std
|
|
119
164
|
|
|
120
|
-
def _get_equivalent_kernel_bias(self) ->
|
|
165
|
+
def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
|
|
121
166
|
kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
|
|
122
167
|
if self.ver_conv is not None:
|
|
123
168
|
kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) # type: ignore[arg-type]
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
@@ -21,7 +21,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
21
21
|
self,
|
|
22
22
|
in_channels: int,
|
|
23
23
|
out_channels: int,
|
|
24
|
-
kernel_size:
|
|
24
|
+
kernel_size: int | tuple[int, int],
|
|
25
25
|
stride: int = 1,
|
|
26
26
|
dilation: int = 1,
|
|
27
27
|
groups: int = 1,
|
|
@@ -103,9 +103,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
103
103
|
|
|
104
104
|
# The following logic is used to reparametrize the layer
|
|
105
105
|
# Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
|
|
106
|
-
def _identity_to_conv(
|
|
107
|
-
self, identity: layers.BatchNormalization
|
|
108
|
-
) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
|
|
106
|
+
def _identity_to_conv(self, identity: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor] | tuple[int, int]:
|
|
109
107
|
if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
|
|
110
108
|
return 0, 0
|
|
111
109
|
if not hasattr(self, "id_tensor"):
|
|
@@ -120,7 +118,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
120
118
|
t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
|
|
121
119
|
return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
|
|
122
120
|
|
|
123
|
-
def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) ->
|
|
121
|
+
def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> tuple[tf.Tensor, tf.Tensor]:
|
|
124
122
|
kernel = conv.kernel
|
|
125
123
|
kernel = self._pad_to_mxn_tensor(kernel)
|
|
126
124
|
std = tf.sqrt(bn.moving_variance + bn.epsilon)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -6,7 +6,8 @@
|
|
|
6
6
|
# This module 'transformer.py' is inspired by https://github.com/wenwenyu/MASTER-pytorch and Decoder is borrowed
|
|
7
7
|
|
|
8
8
|
import math
|
|
9
|
-
from
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from typing import Any
|
|
10
11
|
|
|
11
12
|
import torch
|
|
12
13
|
from torch import nn
|
|
@@ -33,26 +34,24 @@ class PositionalEncoding(nn.Module):
|
|
|
33
34
|
"""Forward pass
|
|
34
35
|
|
|
35
36
|
Args:
|
|
36
|
-
----
|
|
37
37
|
x: embeddings (batch, max_len, d_model)
|
|
38
38
|
|
|
39
|
-
Returns
|
|
40
|
-
-------
|
|
39
|
+
Returns:
|
|
41
40
|
positional embeddings (batch, max_len, d_model)
|
|
42
41
|
"""
|
|
43
|
-
x = x + self.pe[:, : x.size(1)]
|
|
42
|
+
x = x + self.pe[:, : x.size(1)] # type: ignore[index]
|
|
44
43
|
return self.dropout(x)
|
|
45
44
|
|
|
46
45
|
|
|
47
46
|
def scaled_dot_product_attention(
|
|
48
|
-
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask:
|
|
49
|
-
) ->
|
|
47
|
+
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor | None = None
|
|
48
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
50
49
|
"""Scaled Dot-Product Attention"""
|
|
51
50
|
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
|
|
52
51
|
if mask is not None:
|
|
53
52
|
# NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
|
|
54
|
-
scores = scores.masked_fill(mask == 0, float("-inf"))
|
|
55
|
-
p_attn = torch.softmax(scores, dim=-1)
|
|
53
|
+
scores = scores.masked_fill(mask == 0, float("-inf")) # type: ignore[attr-defined]
|
|
54
|
+
p_attn = torch.softmax(scores, dim=-1) # type: ignore[call-overload]
|
|
56
55
|
return torch.matmul(p_attn, value), p_attn
|
|
57
56
|
|
|
58
57
|
|
|
@@ -130,7 +129,7 @@ class EncoderBlock(nn.Module):
|
|
|
130
129
|
PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
|
|
131
130
|
])
|
|
132
131
|
|
|
133
|
-
def forward(self, x: torch.Tensor, mask:
|
|
132
|
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
|
|
134
133
|
output = x
|
|
135
134
|
|
|
136
135
|
for i in range(self.num_layers):
|
|
@@ -183,8 +182,8 @@ class Decoder(nn.Module):
|
|
|
183
182
|
self,
|
|
184
183
|
tgt: torch.Tensor,
|
|
185
184
|
memory: torch.Tensor,
|
|
186
|
-
source_mask:
|
|
187
|
-
target_mask:
|
|
185
|
+
source_mask: torch.Tensor | None = None,
|
|
186
|
+
target_mask: torch.Tensor | None = None,
|
|
188
187
|
) -> torch.Tensor:
|
|
189
188
|
tgt = self.embed(tgt) * math.sqrt(self.d_model)
|
|
190
189
|
pos_enc_tgt = self.positional_encoding(tgt)
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from typing import Any
|
|
8
9
|
|
|
9
10
|
import tensorflow as tf
|
|
10
11
|
from tensorflow.keras import layers
|
|
@@ -43,12 +44,10 @@ class PositionalEncoding(layers.Layer, NestedObject):
|
|
|
43
44
|
"""Forward pass
|
|
44
45
|
|
|
45
46
|
Args:
|
|
46
|
-
----
|
|
47
47
|
x: embeddings (batch, max_len, d_model)
|
|
48
48
|
**kwargs: additional arguments
|
|
49
49
|
|
|
50
|
-
Returns
|
|
51
|
-
-------
|
|
50
|
+
Returns:
|
|
52
51
|
positional embeddings (batch, max_len, d_model)
|
|
53
52
|
"""
|
|
54
53
|
if x.dtype == tf.float16: # amp fix: cast to half
|
|
@@ -60,8 +59,8 @@ class PositionalEncoding(layers.Layer, NestedObject):
|
|
|
60
59
|
|
|
61
60
|
@tf.function
|
|
62
61
|
def scaled_dot_product_attention(
|
|
63
|
-
query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask:
|
|
64
|
-
) ->
|
|
62
|
+
query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: tf.Tensor | None = None
|
|
63
|
+
) -> tuple[tf.Tensor, tf.Tensor]:
|
|
65
64
|
"""Scaled Dot-Product Attention"""
|
|
66
65
|
scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) / math.sqrt(query.shape[-1])
|
|
67
66
|
if mask is not None:
|
|
@@ -160,7 +159,7 @@ class EncoderBlock(layers.Layer, NestedObject):
|
|
|
160
159
|
PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
|
|
161
160
|
]
|
|
162
161
|
|
|
163
|
-
def call(self, x: tf.Tensor, mask:
|
|
162
|
+
def call(self, x: tf.Tensor, mask: tf.Tensor | None = None, **kwargs: Any) -> tf.Tensor:
|
|
164
163
|
output = x
|
|
165
164
|
|
|
166
165
|
for i in range(self.num_layers):
|
|
@@ -210,8 +209,8 @@ class Decoder(layers.Layer, NestedObject):
|
|
|
210
209
|
self,
|
|
211
210
|
tgt: tf.Tensor,
|
|
212
211
|
memory: tf.Tensor,
|
|
213
|
-
source_mask:
|
|
214
|
-
target_mask:
|
|
212
|
+
source_mask: tf.Tensor | None = None,
|
|
213
|
+
target_mask: tf.Tensor | None = None,
|
|
215
214
|
**kwargs: Any,
|
|
216
215
|
) -> tf.Tensor:
|
|
217
216
|
tgt = self.embed(tgt, **kwargs) * math.sqrt(self.d_model)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from typing import Tuple
|
|
8
7
|
|
|
9
8
|
import torch
|
|
10
9
|
from torch import nn
|
|
@@ -15,7 +14,7 @@ __all__ = ["PatchEmbedding"]
|
|
|
15
14
|
class PatchEmbedding(nn.Module):
|
|
16
15
|
"""Compute 2D patch embeddings with cls token and positional encoding"""
|
|
17
16
|
|
|
18
|
-
def __init__(self, input_shape:
|
|
17
|
+
def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
|
|
19
18
|
super().__init__()
|
|
20
19
|
channels, height, width = input_shape
|
|
21
20
|
self.patch_size = patch_size
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import tensorflow as tf
|
|
10
10
|
from tensorflow.keras import layers
|
|
@@ -17,7 +17,7 @@ __all__ = ["PatchEmbedding"]
|
|
|
17
17
|
class PatchEmbedding(layers.Layer, NestedObject):
|
|
18
18
|
"""Compute 2D patch embeddings with cls token and positional encoding"""
|
|
19
19
|
|
|
20
|
-
def __init__(self, input_shape:
|
|
20
|
+
def __init__(self, input_shape: tuple[int, int, int], embed_dim: int, patch_size: tuple[int, int]) -> None:
|
|
21
21
|
super().__init__()
|
|
22
22
|
height, width, _ = input_shape
|
|
23
23
|
self.patch_size = patch_size
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
from doctr.file_utils import is_tf_available
|
|
1
|
+
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|