python-doctr 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +1 -1
- doctr/contrib/__init__.py +0 -0
- doctr/contrib/artefacts.py +131 -0
- doctr/contrib/base.py +105 -0
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/generator/base.py +6 -5
- doctr/datasets/imgur5k.py +1 -1
- doctr/datasets/loader.py +1 -6
- doctr/datasets/utils.py +2 -1
- doctr/datasets/vocabs.py +9 -2
- doctr/file_utils.py +26 -12
- doctr/io/elements.py +40 -6
- doctr/io/html.py +2 -2
- doctr/io/image/pytorch.py +6 -8
- doctr/io/image/tensorflow.py +1 -1
- doctr/io/pdf.py +5 -2
- doctr/io/reader.py +6 -0
- doctr/models/__init__.py +0 -1
- doctr/models/_utils.py +57 -20
- doctr/models/builder.py +71 -13
- doctr/models/classification/mobilenet/pytorch.py +45 -9
- doctr/models/classification/mobilenet/tensorflow.py +38 -7
- doctr/models/classification/predictor/pytorch.py +18 -11
- doctr/models/classification/predictor/tensorflow.py +16 -10
- doctr/models/classification/textnet/pytorch.py +3 -3
- doctr/models/classification/textnet/tensorflow.py +3 -3
- doctr/models/classification/zoo.py +39 -15
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/__init__.py +1 -0
- doctr/models/detection/_utils/base.py +66 -0
- doctr/models/detection/differentiable_binarization/base.py +4 -3
- doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
- doctr/models/detection/differentiable_binarization/tensorflow.py +14 -18
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +257 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +4 -3
- doctr/models/detection/predictor/pytorch.py +15 -1
- doctr/models/detection/predictor/tensorflow.py +15 -1
- doctr/models/detection/zoo.py +21 -4
- doctr/models/factory/hub.py +3 -12
- doctr/models/kie_predictor/base.py +9 -3
- doctr/models/kie_predictor/pytorch.py +41 -20
- doctr/models/kie_predictor/tensorflow.py +36 -16
- doctr/models/modules/layers/pytorch.py +89 -10
- doctr/models/modules/layers/tensorflow.py +88 -10
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/predictor/base.py +77 -50
- doctr/models/predictor/pytorch.py +31 -20
- doctr/models/predictor/tensorflow.py +27 -17
- doctr/models/preprocessor/pytorch.py +4 -4
- doctr/models/preprocessor/tensorflow.py +3 -2
- doctr/models/recognition/master/pytorch.py +2 -2
- doctr/models/recognition/parseq/pytorch.py +4 -3
- doctr/models/recognition/parseq/tensorflow.py +4 -3
- doctr/models/recognition/sar/pytorch.py +7 -6
- doctr/models/recognition/sar/tensorflow.py +3 -9
- doctr/models/recognition/vitstr/pytorch.py +1 -1
- doctr/models/recognition/zoo.py +1 -1
- doctr/models/zoo.py +2 -2
- doctr/py.typed +0 -0
- doctr/transforms/functional/base.py +1 -1
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/base.py +37 -15
- doctr/transforms/modules/pytorch.py +66 -8
- doctr/transforms/modules/tensorflow.py +63 -7
- doctr/utils/fonts.py +7 -5
- doctr/utils/geometry.py +35 -12
- doctr/utils/metrics.py +33 -174
- doctr/utils/reconstitution.py +126 -0
- doctr/utils/visualization.py +5 -118
- doctr/version.py +1 -1
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/METADATA +96 -91
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/RECORD +79 -75
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/WHEEL +1 -1
- doctr/models/artefacts/__init__.py +0 -2
- doctr/models/artefacts/barcode.py +0 -74
- doctr/models/artefacts/face.py +0 -63
- doctr/models/obj_detection/__init__.py +0 -1
- doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/zip-safe +0 -0
|
@@ -10,10 +10,10 @@ import torch
|
|
|
10
10
|
from torch import nn
|
|
11
11
|
|
|
12
12
|
from doctr.io.elements import Document
|
|
13
|
-
from doctr.models._utils import
|
|
13
|
+
from doctr.models._utils import get_language, invert_data_structure
|
|
14
14
|
from doctr.models.detection.predictor import DetectionPredictor
|
|
15
15
|
from doctr.models.recognition.predictor import RecognitionPredictor
|
|
16
|
-
from doctr.utils.geometry import
|
|
16
|
+
from doctr.utils.geometry import detach_scores
|
|
17
17
|
|
|
18
18
|
from .base import _KIEPredictor
|
|
19
19
|
|
|
@@ -55,7 +55,13 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
55
55
|
self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
|
|
56
56
|
self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
|
|
57
57
|
_KIEPredictor.__init__(
|
|
58
|
-
self,
|
|
58
|
+
self,
|
|
59
|
+
assume_straight_pages,
|
|
60
|
+
straighten_pages,
|
|
61
|
+
preserve_aspect_ratio,
|
|
62
|
+
symmetric_pad,
|
|
63
|
+
detect_orientation,
|
|
64
|
+
**kwargs,
|
|
59
65
|
)
|
|
60
66
|
self.detect_orientation = detect_orientation
|
|
61
67
|
self.detect_language = detect_language
|
|
@@ -83,29 +89,31 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
83
89
|
for out_map in out_maps
|
|
84
90
|
]
|
|
85
91
|
if self.detect_orientation:
|
|
86
|
-
|
|
92
|
+
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) # type: ignore[arg-type]
|
|
87
93
|
orientations = [
|
|
88
|
-
{"value": orientation_page, "confidence": None} for orientation_page in
|
|
94
|
+
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
|
|
89
95
|
]
|
|
90
96
|
else:
|
|
91
97
|
orientations = None
|
|
98
|
+
general_pages_orientations = None
|
|
99
|
+
origin_pages_orientations = None
|
|
92
100
|
if self.straighten_pages:
|
|
93
|
-
|
|
94
|
-
origin_page_orientations
|
|
95
|
-
if self.detect_orientation
|
|
96
|
-
else [estimate_orientation(seq_map) for seq_map in seg_maps]
|
|
97
|
-
)
|
|
98
|
-
pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
|
|
101
|
+
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
|
|
99
102
|
# Forward again to get predictions on straight pages
|
|
100
103
|
loc_preds = self.det_predictor(pages, **kwargs)
|
|
101
104
|
|
|
102
105
|
dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]
|
|
106
|
+
|
|
107
|
+
# Detach objectness scores from loc_preds
|
|
108
|
+
objectness_scores = {}
|
|
109
|
+
for class_name, det_preds in dict_loc_preds.items():
|
|
110
|
+
_loc_preds, _scores = detach_scores(det_preds)
|
|
111
|
+
dict_loc_preds[class_name] = _loc_preds
|
|
112
|
+
objectness_scores[class_name] = _scores
|
|
113
|
+
|
|
103
114
|
# Check whether crop mode should be switched to channels first
|
|
104
115
|
channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
|
|
105
116
|
|
|
106
|
-
# Rectify crops if aspect ratio
|
|
107
|
-
dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
|
|
108
|
-
|
|
109
117
|
# Apply hooks to loc_preds if any
|
|
110
118
|
for hook in self.hooks:
|
|
111
119
|
dict_loc_preds = hook(dict_loc_preds)
|
|
@@ -114,32 +122,43 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
114
122
|
crops = {}
|
|
115
123
|
for class_name in dict_loc_preds.keys():
|
|
116
124
|
crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
|
|
117
|
-
pages,
|
|
125
|
+
pages, # type: ignore[arg-type]
|
|
118
126
|
dict_loc_preds[class_name],
|
|
119
127
|
channels_last=channels_last,
|
|
120
128
|
assume_straight_pages=self.assume_straight_pages,
|
|
121
129
|
)
|
|
122
130
|
# Rectify crop orientation
|
|
131
|
+
crop_orientations: Any = {}
|
|
123
132
|
if not self.assume_straight_pages:
|
|
124
133
|
for class_name in dict_loc_preds.keys():
|
|
125
|
-
crops[class_name], dict_loc_preds[class_name] = self._rectify_crops(
|
|
134
|
+
crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
|
|
126
135
|
crops[class_name], dict_loc_preds[class_name]
|
|
127
136
|
)
|
|
137
|
+
crop_orientations[class_name] = [
|
|
138
|
+
{"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
|
|
139
|
+
]
|
|
140
|
+
|
|
128
141
|
# Identify character sequences
|
|
129
142
|
word_preds = {
|
|
130
143
|
k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
|
|
131
144
|
for k, crop_value in crops.items()
|
|
132
145
|
}
|
|
146
|
+
if not crop_orientations:
|
|
147
|
+
crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
|
|
133
148
|
|
|
134
149
|
boxes: Dict = {}
|
|
135
150
|
text_preds: Dict = {}
|
|
151
|
+
word_crop_orientations: Dict = {}
|
|
136
152
|
for class_name in dict_loc_preds.keys():
|
|
137
|
-
boxes[class_name], text_preds[class_name] = self._process_predictions(
|
|
138
|
-
dict_loc_preds[class_name], word_preds[class_name]
|
|
153
|
+
boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
|
|
154
|
+
dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
|
|
139
155
|
)
|
|
140
156
|
|
|
141
157
|
boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
|
|
158
|
+
objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
|
|
142
159
|
text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
|
|
160
|
+
crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
|
|
161
|
+
|
|
143
162
|
if self.detect_language:
|
|
144
163
|
languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
|
|
145
164
|
languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
|
|
@@ -147,10 +166,12 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
147
166
|
languages_dict = None
|
|
148
167
|
|
|
149
168
|
out = self.doc_builder(
|
|
150
|
-
pages,
|
|
169
|
+
pages, # type: ignore[arg-type]
|
|
151
170
|
boxes_per_page,
|
|
171
|
+
objectness_scores_per_page,
|
|
152
172
|
text_preds_per_page,
|
|
153
|
-
origin_page_shapes,
|
|
173
|
+
origin_page_shapes, # type: ignore[arg-type]
|
|
174
|
+
crop_orientations_per_page,
|
|
154
175
|
orientations,
|
|
155
176
|
languages_dict,
|
|
156
177
|
)
|
|
@@ -9,10 +9,10 @@ import numpy as np
|
|
|
9
9
|
import tensorflow as tf
|
|
10
10
|
|
|
11
11
|
from doctr.io.elements import Document
|
|
12
|
-
from doctr.models._utils import
|
|
12
|
+
from doctr.models._utils import get_language, invert_data_structure
|
|
13
13
|
from doctr.models.detection.predictor import DetectionPredictor
|
|
14
14
|
from doctr.models.recognition.predictor import RecognitionPredictor
|
|
15
|
-
from doctr.utils.geometry import
|
|
15
|
+
from doctr.utils.geometry import detach_scores
|
|
16
16
|
from doctr.utils.repr import NestedObject
|
|
17
17
|
|
|
18
18
|
from .base import _KIEPredictor
|
|
@@ -56,7 +56,13 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
56
56
|
self.det_predictor = det_predictor
|
|
57
57
|
self.reco_predictor = reco_predictor
|
|
58
58
|
_KIEPredictor.__init__(
|
|
59
|
-
self,
|
|
59
|
+
self,
|
|
60
|
+
assume_straight_pages,
|
|
61
|
+
straighten_pages,
|
|
62
|
+
preserve_aspect_ratio,
|
|
63
|
+
symmetric_pad,
|
|
64
|
+
detect_orientation,
|
|
65
|
+
**kwargs,
|
|
60
66
|
)
|
|
61
67
|
self.detect_orientation = detect_orientation
|
|
62
68
|
self.detect_language = detect_language
|
|
@@ -83,25 +89,27 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
83
89
|
for out_map in out_maps
|
|
84
90
|
]
|
|
85
91
|
if self.detect_orientation:
|
|
86
|
-
|
|
92
|
+
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
87
93
|
orientations = [
|
|
88
|
-
{"value": orientation_page, "confidence": None} for orientation_page in
|
|
94
|
+
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
|
|
89
95
|
]
|
|
90
96
|
else:
|
|
91
97
|
orientations = None
|
|
98
|
+
general_pages_orientations = None
|
|
99
|
+
origin_pages_orientations = None
|
|
92
100
|
if self.straighten_pages:
|
|
93
|
-
|
|
94
|
-
origin_page_orientations
|
|
95
|
-
if self.detect_orientation
|
|
96
|
-
else [estimate_orientation(seq_map) for seq_map in seg_maps]
|
|
97
|
-
)
|
|
98
|
-
pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
|
|
101
|
+
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
99
102
|
# Forward again to get predictions on straight pages
|
|
100
103
|
loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
|
|
101
104
|
|
|
102
105
|
dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
|
|
103
|
-
|
|
104
|
-
|
|
106
|
+
|
|
107
|
+
# Detach objectness scores from loc_preds
|
|
108
|
+
objectness_scores = {}
|
|
109
|
+
for class_name, det_preds in dict_loc_preds.items():
|
|
110
|
+
_loc_preds, _scores = detach_scores(det_preds)
|
|
111
|
+
dict_loc_preds[class_name] = _loc_preds
|
|
112
|
+
objectness_scores[class_name] = _scores
|
|
105
113
|
|
|
106
114
|
# Apply hooks to loc_preds if any
|
|
107
115
|
for hook in self.hooks:
|
|
@@ -113,28 +121,38 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
113
121
|
crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
|
|
114
122
|
pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
|
|
115
123
|
)
|
|
124
|
+
|
|
116
125
|
# Rectify crop orientation
|
|
126
|
+
crop_orientations: Any = {}
|
|
117
127
|
if not self.assume_straight_pages:
|
|
118
128
|
for class_name in dict_loc_preds.keys():
|
|
119
|
-
crops[class_name], dict_loc_preds[class_name] = self._rectify_crops(
|
|
129
|
+
crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
|
|
120
130
|
crops[class_name], dict_loc_preds[class_name]
|
|
121
131
|
)
|
|
132
|
+
crop_orientations[class_name] = [
|
|
133
|
+
{"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
|
|
134
|
+
]
|
|
122
135
|
|
|
123
136
|
# Identify character sequences
|
|
124
137
|
word_preds = {
|
|
125
138
|
k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
|
|
126
139
|
for k, crop_value in crops.items()
|
|
127
140
|
}
|
|
141
|
+
if not crop_orientations:
|
|
142
|
+
crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
|
|
128
143
|
|
|
129
144
|
boxes: Dict = {}
|
|
130
145
|
text_preds: Dict = {}
|
|
146
|
+
word_crop_orientations: Dict = {}
|
|
131
147
|
for class_name in dict_loc_preds.keys():
|
|
132
|
-
boxes[class_name], text_preds[class_name] = self._process_predictions(
|
|
133
|
-
dict_loc_preds[class_name], word_preds[class_name]
|
|
148
|
+
boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
|
|
149
|
+
dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
|
|
134
150
|
)
|
|
135
151
|
|
|
136
152
|
boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
|
|
153
|
+
objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
|
|
137
154
|
text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
|
|
155
|
+
crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
|
|
138
156
|
|
|
139
157
|
if self.detect_language:
|
|
140
158
|
languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
|
|
@@ -145,8 +163,10 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
145
163
|
out = self.doc_builder(
|
|
146
164
|
pages,
|
|
147
165
|
boxes_per_page,
|
|
166
|
+
objectness_scores_per_page,
|
|
148
167
|
text_preds_per_page,
|
|
149
168
|
origin_page_shapes, # type: ignore[arg-type]
|
|
169
|
+
crop_orientations_per_page,
|
|
150
170
|
orientations,
|
|
151
171
|
languages_dict,
|
|
152
172
|
)
|
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
|
|
6
6
|
from typing import Tuple, Union
|
|
7
7
|
|
|
8
|
+
import numpy as np
|
|
8
9
|
import torch
|
|
9
10
|
import torch.nn as nn
|
|
10
11
|
|
|
@@ -26,18 +27,20 @@ class FASTConvLayer(nn.Module):
|
|
|
26
27
|
) -> None:
|
|
27
28
|
super().__init__()
|
|
28
29
|
|
|
29
|
-
|
|
30
|
+
self.groups = groups
|
|
31
|
+
self.in_channels = in_channels
|
|
32
|
+
self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
|
30
33
|
|
|
31
34
|
self.hor_conv, self.hor_bn = None, None
|
|
32
35
|
self.ver_conv, self.ver_bn = None, None
|
|
33
36
|
|
|
34
|
-
padding = (int(((converted_ks[0] - 1) * dilation) / 2), int(((converted_ks[1] - 1) * dilation) / 2))
|
|
37
|
+
padding = (int(((self.converted_ks[0] - 1) * dilation) / 2), int(((self.converted_ks[1] - 1) * dilation) / 2))
|
|
35
38
|
|
|
36
39
|
self.activation = nn.ReLU(inplace=True)
|
|
37
40
|
self.conv = nn.Conv2d(
|
|
38
41
|
in_channels,
|
|
39
42
|
out_channels,
|
|
40
|
-
kernel_size=converted_ks,
|
|
43
|
+
kernel_size=self.converted_ks,
|
|
41
44
|
stride=stride,
|
|
42
45
|
padding=padding,
|
|
43
46
|
dilation=dilation,
|
|
@@ -47,12 +50,12 @@ class FASTConvLayer(nn.Module):
|
|
|
47
50
|
|
|
48
51
|
self.bn = nn.BatchNorm2d(out_channels)
|
|
49
52
|
|
|
50
|
-
if converted_ks[1] != 1:
|
|
53
|
+
if self.converted_ks[1] != 1:
|
|
51
54
|
self.ver_conv = nn.Conv2d(
|
|
52
55
|
in_channels,
|
|
53
56
|
out_channels,
|
|
54
|
-
kernel_size=(converted_ks[0], 1),
|
|
55
|
-
padding=(int(((converted_ks[0] - 1) * dilation) / 2), 0),
|
|
57
|
+
kernel_size=(self.converted_ks[0], 1),
|
|
58
|
+
padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
|
|
56
59
|
stride=stride,
|
|
57
60
|
dilation=dilation,
|
|
58
61
|
groups=groups,
|
|
@@ -60,12 +63,12 @@ class FASTConvLayer(nn.Module):
|
|
|
60
63
|
)
|
|
61
64
|
self.ver_bn = nn.BatchNorm2d(out_channels)
|
|
62
65
|
|
|
63
|
-
if converted_ks[0] != 1:
|
|
66
|
+
if self.converted_ks[0] != 1:
|
|
64
67
|
self.hor_conv = nn.Conv2d(
|
|
65
68
|
in_channels,
|
|
66
69
|
out_channels,
|
|
67
|
-
kernel_size=(1, converted_ks[1]),
|
|
68
|
-
padding=(0, int(((converted_ks[1] - 1) * dilation) / 2)),
|
|
70
|
+
kernel_size=(1, self.converted_ks[1]),
|
|
71
|
+
padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
|
|
69
72
|
stride=stride,
|
|
70
73
|
dilation=dilation,
|
|
71
74
|
groups=groups,
|
|
@@ -76,11 +79,87 @@ class FASTConvLayer(nn.Module):
|
|
|
76
79
|
self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None
|
|
77
80
|
|
|
78
81
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
82
|
+
if hasattr(self, "fused_conv"):
|
|
83
|
+
return self.activation(self.fused_conv(x))
|
|
84
|
+
|
|
79
85
|
main_outputs = self.bn(self.conv(x))
|
|
80
86
|
vertical_outputs = self.ver_bn(self.ver_conv(x)) if self.ver_conv is not None and self.ver_bn is not None else 0
|
|
81
87
|
horizontal_outputs = (
|
|
82
88
|
self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0
|
|
83
89
|
)
|
|
84
|
-
id_out = self.rbr_identity(x) if self.rbr_identity is not None
|
|
90
|
+
id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0
|
|
85
91
|
|
|
86
92
|
return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
|
|
93
|
+
|
|
94
|
+
# The following logic is used to reparametrize the layer
|
|
95
|
+
# Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
|
|
96
|
+
def _identity_to_conv(
|
|
97
|
+
self, identity: Union[nn.BatchNorm2d, None]
|
|
98
|
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
|
99
|
+
if identity is None or identity.running_var is None:
|
|
100
|
+
return 0, 0
|
|
101
|
+
if not hasattr(self, "id_tensor"):
|
|
102
|
+
input_dim = self.in_channels // self.groups
|
|
103
|
+
kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
|
|
104
|
+
for i in range(self.in_channels):
|
|
105
|
+
kernel_value[i, i % input_dim, 0, 0] = 1
|
|
106
|
+
id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
|
|
107
|
+
self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
|
|
108
|
+
kernel = self.id_tensor
|
|
109
|
+
std = (identity.running_var + identity.eps).sqrt()
|
|
110
|
+
t = (identity.weight / std).reshape(-1, 1, 1, 1)
|
|
111
|
+
return kernel * t, identity.bias - identity.running_mean * identity.weight / std
|
|
112
|
+
|
|
113
|
+
def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
114
|
+
kernel = conv.weight
|
|
115
|
+
kernel = self._pad_to_mxn_tensor(kernel)
|
|
116
|
+
std = (bn.running_var + bn.eps).sqrt() # type: ignore
|
|
117
|
+
t = (bn.weight / std).reshape(-1, 1, 1, 1)
|
|
118
|
+
return kernel * t, bn.bias - bn.running_mean * bn.weight / std
|
|
119
|
+
|
|
120
|
+
def _get_equivalent_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
121
|
+
kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
|
|
122
|
+
if self.ver_conv is not None:
|
|
123
|
+
kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) # type: ignore[arg-type]
|
|
124
|
+
else:
|
|
125
|
+
kernel_mx1, bias_mx1 = 0, 0 # type: ignore[assignment]
|
|
126
|
+
if self.hor_conv is not None:
|
|
127
|
+
kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn) # type: ignore[arg-type]
|
|
128
|
+
else:
|
|
129
|
+
kernel_1xn, bias_1xn = 0, 0 # type: ignore[assignment]
|
|
130
|
+
kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
|
|
131
|
+
kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
|
|
132
|
+
bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
|
|
133
|
+
return kernel_mxn, bias_mxn
|
|
134
|
+
|
|
135
|
+
def _pad_to_mxn_tensor(self, kernel: torch.Tensor) -> torch.Tensor:
|
|
136
|
+
kernel_height, kernel_width = self.converted_ks
|
|
137
|
+
height, width = kernel.shape[2:]
|
|
138
|
+
pad_left_right = (kernel_width - width) // 2
|
|
139
|
+
pad_top_down = (kernel_height - height) // 2
|
|
140
|
+
return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right, pad_top_down, pad_top_down], value=0)
|
|
141
|
+
|
|
142
|
+
def reparameterize_layer(self):
|
|
143
|
+
if hasattr(self, "fused_conv"):
|
|
144
|
+
return
|
|
145
|
+
kernel, bias = self._get_equivalent_kernel_bias()
|
|
146
|
+
self.fused_conv = nn.Conv2d(
|
|
147
|
+
in_channels=self.conv.in_channels,
|
|
148
|
+
out_channels=self.conv.out_channels,
|
|
149
|
+
kernel_size=self.conv.kernel_size, # type: ignore[arg-type]
|
|
150
|
+
stride=self.conv.stride, # type: ignore[arg-type]
|
|
151
|
+
padding=self.conv.padding, # type: ignore[arg-type]
|
|
152
|
+
dilation=self.conv.dilation, # type: ignore[arg-type]
|
|
153
|
+
groups=self.conv.groups,
|
|
154
|
+
bias=True,
|
|
155
|
+
)
|
|
156
|
+
self.fused_conv.weight.data = kernel
|
|
157
|
+
self.fused_conv.bias.data = bias # type: ignore[union-attr]
|
|
158
|
+
for para in self.parameters():
|
|
159
|
+
para.detach_()
|
|
160
|
+
for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
|
|
161
|
+
if hasattr(self, attr):
|
|
162
|
+
self.__delattr__(attr)
|
|
163
|
+
|
|
164
|
+
if hasattr(self, "rbr_identity"):
|
|
165
|
+
self.__delattr__("rbr_identity")
|
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
|
|
6
6
|
from typing import Any, Tuple, Union
|
|
7
7
|
|
|
8
|
+
import numpy as np
|
|
8
9
|
import tensorflow as tf
|
|
9
10
|
from tensorflow.keras import layers
|
|
10
11
|
|
|
@@ -28,18 +29,21 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
28
29
|
) -> None:
|
|
29
30
|
super().__init__()
|
|
30
31
|
|
|
31
|
-
|
|
32
|
+
self.groups = groups
|
|
33
|
+
self.in_channels = in_channels
|
|
34
|
+
self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
|
32
35
|
|
|
33
36
|
self.hor_conv, self.hor_bn = None, None
|
|
34
37
|
self.ver_conv, self.ver_bn = None, None
|
|
35
38
|
|
|
36
|
-
padding = ((converted_ks[0] - 1) * dilation // 2, (converted_ks[1] - 1) * dilation // 2)
|
|
39
|
+
padding = ((self.converted_ks[0] - 1) * dilation // 2, (self.converted_ks[1] - 1) * dilation // 2)
|
|
37
40
|
|
|
38
41
|
self.activation = layers.ReLU()
|
|
39
42
|
self.conv_pad = layers.ZeroPadding2D(padding=padding)
|
|
43
|
+
|
|
40
44
|
self.conv = layers.Conv2D(
|
|
41
45
|
filters=out_channels,
|
|
42
|
-
kernel_size=converted_ks,
|
|
46
|
+
kernel_size=self.converted_ks,
|
|
43
47
|
strides=stride,
|
|
44
48
|
dilation_rate=dilation,
|
|
45
49
|
groups=groups,
|
|
@@ -48,13 +52,13 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
48
52
|
|
|
49
53
|
self.bn = layers.BatchNormalization()
|
|
50
54
|
|
|
51
|
-
if converted_ks[1] != 1:
|
|
55
|
+
if self.converted_ks[1] != 1:
|
|
52
56
|
self.ver_pad = layers.ZeroPadding2D(
|
|
53
|
-
padding=(int(((converted_ks[0] - 1) * dilation) / 2), 0),
|
|
57
|
+
padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
|
|
54
58
|
)
|
|
55
59
|
self.ver_conv = layers.Conv2D(
|
|
56
60
|
filters=out_channels,
|
|
57
|
-
kernel_size=(converted_ks[0], 1),
|
|
61
|
+
kernel_size=(self.converted_ks[0], 1),
|
|
58
62
|
strides=stride,
|
|
59
63
|
dilation_rate=dilation,
|
|
60
64
|
groups=groups,
|
|
@@ -62,13 +66,13 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
62
66
|
)
|
|
63
67
|
self.ver_bn = layers.BatchNormalization()
|
|
64
68
|
|
|
65
|
-
if converted_ks[0] != 1:
|
|
69
|
+
if self.converted_ks[0] != 1:
|
|
66
70
|
self.hor_pad = layers.ZeroPadding2D(
|
|
67
|
-
padding=(0, int(((converted_ks[1] - 1) * dilation) / 2)),
|
|
71
|
+
padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
|
|
68
72
|
)
|
|
69
73
|
self.hor_conv = layers.Conv2D(
|
|
70
74
|
filters=out_channels,
|
|
71
|
-
kernel_size=(1, converted_ks[1]),
|
|
75
|
+
kernel_size=(1, self.converted_ks[1]),
|
|
72
76
|
strides=stride,
|
|
73
77
|
dilation_rate=dilation,
|
|
74
78
|
groups=groups,
|
|
@@ -79,6 +83,9 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
79
83
|
self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
|
|
80
84
|
|
|
81
85
|
def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
|
86
|
+
if hasattr(self, "fused_conv"):
|
|
87
|
+
return self.activation(self.fused_conv(self.conv_pad(x, **kwargs), **kwargs))
|
|
88
|
+
|
|
82
89
|
main_outputs = self.bn(self.conv(self.conv_pad(x, **kwargs), **kwargs), **kwargs)
|
|
83
90
|
vertical_outputs = (
|
|
84
91
|
self.ver_bn(self.ver_conv(self.ver_pad(x, **kwargs), **kwargs), **kwargs)
|
|
@@ -90,6 +97,77 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
90
97
|
if self.hor_bn is not None and self.hor_conv is not None
|
|
91
98
|
else 0
|
|
92
99
|
)
|
|
93
|
-
id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None
|
|
100
|
+
id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
|
|
94
101
|
|
|
95
102
|
return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
|
|
103
|
+
|
|
104
|
+
# The following logic is used to reparametrize the layer
|
|
105
|
+
# Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
|
|
106
|
+
def _identity_to_conv(
|
|
107
|
+
self, identity: layers.BatchNormalization
|
|
108
|
+
) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
|
|
109
|
+
if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
|
|
110
|
+
return 0, 0
|
|
111
|
+
if not hasattr(self, "id_tensor"):
|
|
112
|
+
input_dim = self.in_channels // self.groups
|
|
113
|
+
kernel_value = np.zeros((1, 1, input_dim, self.in_channels), dtype=np.float32)
|
|
114
|
+
for i in range(self.in_channels):
|
|
115
|
+
kernel_value[0, 0, i % input_dim, i] = 1
|
|
116
|
+
id_tensor = tf.constant(kernel_value, dtype=tf.float32)
|
|
117
|
+
self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
|
|
118
|
+
kernel = self.id_tensor
|
|
119
|
+
std = tf.sqrt(identity.moving_variance + identity.epsilon)
|
|
120
|
+
t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
|
|
121
|
+
return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
|
|
122
|
+
|
|
123
|
+
def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
|
|
124
|
+
kernel = conv.kernel
|
|
125
|
+
kernel = self._pad_to_mxn_tensor(kernel)
|
|
126
|
+
std = tf.sqrt(bn.moving_variance + bn.epsilon)
|
|
127
|
+
t = tf.reshape(bn.gamma / std, (1, 1, 1, -1))
|
|
128
|
+
return kernel * t, bn.beta - bn.moving_mean * bn.gamma / std
|
|
129
|
+
|
|
130
|
+
def _get_equivalent_kernel_bias(self):
|
|
131
|
+
kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
|
|
132
|
+
if self.ver_conv is not None:
|
|
133
|
+
kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)
|
|
134
|
+
else:
|
|
135
|
+
kernel_mx1, bias_mx1 = 0, 0
|
|
136
|
+
if self.hor_conv is not None:
|
|
137
|
+
kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)
|
|
138
|
+
else:
|
|
139
|
+
kernel_1xn, bias_1xn = 0, 0
|
|
140
|
+
kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
|
|
141
|
+
kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
|
|
142
|
+
bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
|
|
143
|
+
return kernel_mxn, bias_mxn
|
|
144
|
+
|
|
145
|
+
def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor:
|
|
146
|
+
kernel_height, kernel_width = self.converted_ks
|
|
147
|
+
height, width = kernel.shape[:2]
|
|
148
|
+
pad_left_right = tf.maximum(0, (kernel_width - width) // 2)
|
|
149
|
+
pad_top_down = tf.maximum(0, (kernel_height - height) // 2)
|
|
150
|
+
return tf.pad(kernel, [[pad_top_down, pad_top_down], [pad_left_right, pad_left_right], [0, 0], [0, 0]])
|
|
151
|
+
|
|
152
|
+
def reparameterize_layer(self):
|
|
153
|
+
kernel, bias = self._get_equivalent_kernel_bias()
|
|
154
|
+
self.fused_conv = layers.Conv2D(
|
|
155
|
+
filters=self.conv.filters,
|
|
156
|
+
kernel_size=self.conv.kernel_size,
|
|
157
|
+
strides=self.conv.strides,
|
|
158
|
+
padding=self.conv.padding,
|
|
159
|
+
dilation_rate=self.conv.dilation_rate,
|
|
160
|
+
groups=self.conv.groups,
|
|
161
|
+
use_bias=True,
|
|
162
|
+
)
|
|
163
|
+
# build layer to initialize weights and biases
|
|
164
|
+
self.fused_conv.build(input_shape=(None, None, None, kernel.shape[-2]))
|
|
165
|
+
self.fused_conv.set_weights([kernel.numpy(), bias.numpy()])
|
|
166
|
+
for para in self.trainable_variables:
|
|
167
|
+
para._trainable = False
|
|
168
|
+
for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
|
|
169
|
+
if hasattr(self, attr):
|
|
170
|
+
delattr(self, attr)
|
|
171
|
+
|
|
172
|
+
if hasattr(self, "rbr_identity"):
|
|
173
|
+
delattr(self, "rbr_identity")
|
|
@@ -51,8 +51,8 @@ def scaled_dot_product_attention(
|
|
|
51
51
|
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
|
|
52
52
|
if mask is not None:
|
|
53
53
|
# NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
|
|
54
|
-
scores = scores.masked_fill(mask == 0, float("-inf"))
|
|
55
|
-
p_attn = torch.softmax(scores, dim=-1)
|
|
54
|
+
scores = scores.masked_fill(mask == 0, float("-inf"))
|
|
55
|
+
p_attn = torch.softmax(scores, dim=-1)
|
|
56
56
|
return torch.matmul(p_attn, value), p_attn
|
|
57
57
|
|
|
58
58
|
|