python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- doctr/__init__.py +1 -1
- doctr/contrib/__init__.py +0 -0
- doctr/contrib/artefacts.py +131 -0
- doctr/contrib/base.py +105 -0
- doctr/datasets/cord.py +10 -1
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +11 -1
- doctr/datasets/generator/base.py +6 -5
- doctr/datasets/ic03.py +11 -1
- doctr/datasets/ic13.py +10 -1
- doctr/datasets/iiit5k.py +26 -16
- doctr/datasets/imgur5k.py +11 -2
- doctr/datasets/loader.py +1 -6
- doctr/datasets/sroie.py +11 -1
- doctr/datasets/svhn.py +11 -1
- doctr/datasets/svt.py +11 -1
- doctr/datasets/synthtext.py +11 -1
- doctr/datasets/utils.py +9 -3
- doctr/datasets/vocabs.py +15 -4
- doctr/datasets/wildreceipt.py +12 -1
- doctr/file_utils.py +45 -12
- doctr/io/elements.py +52 -10
- doctr/io/html.py +2 -2
- doctr/io/image/pytorch.py +6 -8
- doctr/io/image/tensorflow.py +1 -1
- doctr/io/pdf.py +5 -2
- doctr/io/reader.py +6 -0
- doctr/models/__init__.py +0 -1
- doctr/models/_utils.py +57 -20
- doctr/models/builder.py +73 -15
- doctr/models/classification/magc_resnet/tensorflow.py +13 -6
- doctr/models/classification/mobilenet/pytorch.py +47 -9
- doctr/models/classification/mobilenet/tensorflow.py +51 -14
- doctr/models/classification/predictor/pytorch.py +28 -17
- doctr/models/classification/predictor/tensorflow.py +26 -16
- doctr/models/classification/resnet/tensorflow.py +21 -8
- doctr/models/classification/textnet/pytorch.py +3 -3
- doctr/models/classification/textnet/tensorflow.py +11 -5
- doctr/models/classification/vgg/tensorflow.py +9 -3
- doctr/models/classification/vit/tensorflow.py +10 -4
- doctr/models/classification/zoo.py +55 -19
- doctr/models/detection/_utils/__init__.py +1 -0
- doctr/models/detection/_utils/base.py +66 -0
- doctr/models/detection/differentiable_binarization/base.py +4 -3
- doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
- doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
- doctr/models/detection/fast/base.py +6 -5
- doctr/models/detection/fast/pytorch.py +4 -4
- doctr/models/detection/fast/tensorflow.py +15 -12
- doctr/models/detection/linknet/base.py +4 -3
- doctr/models/detection/linknet/tensorflow.py +23 -11
- doctr/models/detection/predictor/pytorch.py +15 -1
- doctr/models/detection/predictor/tensorflow.py +17 -3
- doctr/models/detection/zoo.py +7 -2
- doctr/models/factory/hub.py +8 -18
- doctr/models/kie_predictor/base.py +13 -3
- doctr/models/kie_predictor/pytorch.py +45 -20
- doctr/models/kie_predictor/tensorflow.py +44 -17
- doctr/models/modules/layers/pytorch.py +2 -3
- doctr/models/modules/layers/tensorflow.py +6 -8
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/transformer/tensorflow.py +0 -2
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/modules/vision_transformer/tensorflow.py +1 -1
- doctr/models/predictor/base.py +97 -58
- doctr/models/predictor/pytorch.py +35 -20
- doctr/models/predictor/tensorflow.py +35 -18
- doctr/models/preprocessor/pytorch.py +4 -4
- doctr/models/preprocessor/tensorflow.py +3 -2
- doctr/models/recognition/crnn/tensorflow.py +8 -6
- doctr/models/recognition/master/pytorch.py +2 -2
- doctr/models/recognition/master/tensorflow.py +9 -4
- doctr/models/recognition/parseq/pytorch.py +4 -3
- doctr/models/recognition/parseq/tensorflow.py +14 -11
- doctr/models/recognition/sar/pytorch.py +7 -6
- doctr/models/recognition/sar/tensorflow.py +10 -12
- doctr/models/recognition/vitstr/pytorch.py +1 -1
- doctr/models/recognition/vitstr/tensorflow.py +9 -4
- doctr/models/recognition/zoo.py +1 -1
- doctr/models/utils/pytorch.py +1 -1
- doctr/models/utils/tensorflow.py +15 -15
- doctr/models/zoo.py +2 -2
- doctr/py.typed +0 -0
- doctr/transforms/functional/base.py +1 -1
- doctr/transforms/functional/pytorch.py +5 -5
- doctr/transforms/modules/base.py +37 -15
- doctr/transforms/modules/pytorch.py +73 -14
- doctr/transforms/modules/tensorflow.py +78 -19
- doctr/utils/fonts.py +7 -5
- doctr/utils/geometry.py +141 -31
- doctr/utils/metrics.py +34 -175
- doctr/utils/reconstitution.py +212 -0
- doctr/utils/visualization.py +5 -118
- doctr/version.py +1 -1
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
- python_doctr-0.10.0.dist-info/RECORD +173 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
- doctr/models/artefacts/__init__.py +0 -2
- doctr/models/artefacts/barcode.py +0 -74
- doctr/models/artefacts/face.py +0 -63
- doctr/models/obj_detection/__init__.py +0 -1
- doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
- python_doctr-0.8.1.dist-info/RECORD +0 -173
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/kie_predictor/pytorch.py
CHANGED
@@ -10,10 +10,10 @@ import torch
 from torch import nn
 
 from doctr.io.elements import Document
-from doctr.models._utils import
+from doctr.models._utils import get_language, invert_data_structure
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import
+from doctr.utils.geometry import detach_scores
 
 from .base import _KIEPredictor
 
@@ -55,7 +55,13 @@ class KIEPredictor(nn.Module, _KIEPredictor):
         self.det_predictor = det_predictor.eval()  # type: ignore[attr-defined]
         self.reco_predictor = reco_predictor.eval()  # type: ignore[attr-defined]
         _KIEPredictor.__init__(
-            self,
+            self,
+            assume_straight_pages,
+            straighten_pages,
+            preserve_aspect_ratio,
+            symmetric_pad,
+            detect_orientation,
+            **kwargs,
         )
         self.detect_orientation = detect_orientation
         self.detect_language = detect_language
@@ -83,29 +89,34 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             for out_map in out_maps
         ]
         if self.detect_orientation:
-
+            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)  # type: ignore[arg-type]
             orientations = [
-                {"value": orientation_page, "confidence": None} for orientation_page in
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
             ]
         else:
             orientations = None
+            general_pages_orientations = None
+            origin_pages_orientations = None
         if self.straighten_pages:
-
-
-
-
-            )
-            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)  # type: ignore
+            # update page shapes after straightening
+            origin_page_shapes = [page.shape[:2] for page in pages]
+
             # Forward again to get predictions on straight pages
             loc_preds = self.det_predictor(pages, **kwargs)
 
         dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore[assignment]
+
+        # Detach objectness scores from loc_preds
+        objectness_scores = {}
+        for class_name, det_preds in dict_loc_preds.items():
+            _loc_preds, _scores = detach_scores(det_preds)
+            dict_loc_preds[class_name] = _loc_preds
+            objectness_scores[class_name] = _scores
+
         # Check whether crop mode should be switched to channels first
         channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
 
-        # Rectify crops if aspect ratio
-        dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
-
         # Apply hooks to loc_preds if any
         for hook in self.hooks:
             dict_loc_preds = hook(dict_loc_preds)
@@ -114,32 +125,44 @@ class KIEPredictor(nn.Module, _KIEPredictor):
         crops = {}
         for class_name in dict_loc_preds.keys():
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
-                pages,
+                pages,  # type: ignore[arg-type]
                 dict_loc_preds[class_name],
                 channels_last=channels_last,
                 assume_straight_pages=self.assume_straight_pages,
+                assume_horizontal=self._page_orientation_disabled,
             )
         # Rectify crop orientation
+        crop_orientations: Any = {}
         if not self.assume_straight_pages:
             for class_name in dict_loc_preds.keys():
-                crops[class_name], dict_loc_preds[class_name] = self._rectify_crops(
+                crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
                     crops[class_name], dict_loc_preds[class_name]
                 )
+                crop_orientations[class_name] = [
+                    {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
+                ]
+
         # Identify character sequences
         word_preds = {
             k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
             for k, crop_value in crops.items()
         }
+        if not crop_orientations:
+            crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
 
         boxes: Dict = {}
         text_preds: Dict = {}
+        word_crop_orientations: Dict = {}
         for class_name in dict_loc_preds.keys():
-            boxes[class_name], text_preds[class_name] = self._process_predictions(
-                dict_loc_preds[class_name], word_preds[class_name]
+            boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
+                dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
            )
 
         boxes_per_page: List[Dict] = invert_data_structure(boxes)  # type: ignore[assignment]
+        objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores)  # type: ignore[assignment]
         text_preds_per_page: List[Dict] = invert_data_structure(text_preds)  # type: ignore[assignment]
+        crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations)  # type: ignore[assignment]
+
         if self.detect_language:
             languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
             languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
@@ -147,10 +170,12 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             languages_dict = None
 
         out = self.doc_builder(
-            pages,
+            pages,  # type: ignore[arg-type]
             boxes_per_page,
+            objectness_scores_per_page,
             text_preds_per_page,
-            origin_page_shapes,
+            origin_page_shapes,  # type: ignore[arg-type]
+            crop_orientations_per_page,
             orientations,
             languages_dict,
         )
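The hunks above mean the KIE predictor now forwards its orientation-related options to the base class and hands objectness scores plus per-word crop orientations to the document builder. A minimal usage sketch (not part of the diff), assuming the public zoo helpers `kie_predictor` and `DocumentFile` keep their documented signatures; the image path is a placeholder:

# Sketch only: keyword names mirror the constructor arguments shown above.
from doctr.io import DocumentFile
from doctr.models import kie_predictor

model = kie_predictor(
    pretrained=True,
    assume_straight_pages=False,  # enables crop/page orientation handling
    detect_orientation=True,      # adds a per-page orientation entry to the output
    straighten_pages=True,        # rotates pages before the second detection pass
)
doc = DocumentFile.from_images(["sample_page.jpg"])
result = model(doc)
print(result.pages[0].orientation)  # e.g. {"value": 0, "confidence": ...}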
doctr/models/kie_predictor/tensorflow.py
CHANGED
@@ -9,10 +9,10 @@ import numpy as np
 import tensorflow as tf
 
 from doctr.io.elements import Document
-from doctr.models._utils import
+from doctr.models._utils import get_language, invert_data_structure
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import
+from doctr.utils.geometry import detach_scores
 from doctr.utils.repr import NestedObject
 
 from .base import _KIEPredictor
@@ -56,7 +56,13 @@ class KIEPredictor(NestedObject, _KIEPredictor):
         self.det_predictor = det_predictor
         self.reco_predictor = reco_predictor
         _KIEPredictor.__init__(
-            self,
+            self,
+            assume_straight_pages,
+            straighten_pages,
+            preserve_aspect_ratio,
+            symmetric_pad,
+            detect_orientation,
+            **kwargs,
         )
         self.detect_orientation = detect_orientation
         self.detect_language = detect_language
@@ -83,25 +89,30 @@ class KIEPredictor(NestedObject, _KIEPredictor):
             for out_map in out_maps
         ]
         if self.detect_orientation:
-
+            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
             orientations = [
-                {"value": orientation_page, "confidence": None} for orientation_page in
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
             ]
         else:
             orientations = None
+            general_pages_orientations = None
+            origin_pages_orientations = None
         if self.straighten_pages:
-
-
-
-
-            )
-            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+            # update page shapes after straightening
+            origin_page_shapes = [page.shape[:2] for page in pages]
+
             # Forward again to get predictions on straight pages
             loc_preds = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
 
         dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore
-
-
+
+        # Detach objectness scores from loc_preds
+        objectness_scores = {}
+        for class_name, det_preds in dict_loc_preds.items():
+            _loc_preds, _scores = detach_scores(det_preds)
+            dict_loc_preds[class_name] = _loc_preds
+            objectness_scores[class_name] = _scores
 
         # Apply hooks to loc_preds if any
         for hook in self.hooks:
@@ -111,30 +122,44 @@ class KIEPredictor(NestedObject, _KIEPredictor):
         crops = {}
         for class_name in dict_loc_preds.keys():
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
-                pages,
+                pages,
+                dict_loc_preds[class_name],
+                channels_last=True,
+                assume_straight_pages=self.assume_straight_pages,
+                assume_horizontal=self._page_orientation_disabled,
             )
+
         # Rectify crop orientation
+        crop_orientations: Any = {}
         if not self.assume_straight_pages:
             for class_name in dict_loc_preds.keys():
-                crops[class_name], dict_loc_preds[class_name] = self._rectify_crops(
+                crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
                     crops[class_name], dict_loc_preds[class_name]
                 )
+                crop_orientations[class_name] = [
+                    {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
+                ]
 
         # Identify character sequences
         word_preds = {
             k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
             for k, crop_value in crops.items()
         }
+        if not crop_orientations:
+            crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
 
         boxes: Dict = {}
         text_preds: Dict = {}
+        word_crop_orientations: Dict = {}
         for class_name in dict_loc_preds.keys():
-            boxes[class_name], text_preds[class_name] = self._process_predictions(
-                dict_loc_preds[class_name], word_preds[class_name]
+            boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
+                dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
             )
 
         boxes_per_page: List[Dict] = invert_data_structure(boxes)  # type: ignore[assignment]
+        objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores)  # type: ignore[assignment]
         text_preds_per_page: List[Dict] = invert_data_structure(text_preds)  # type: ignore[assignment]
+        crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations)  # type: ignore[assignment]
 
         if self.detect_language:
             languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -145,8 +170,10 @@ class KIEPredictor(NestedObject, _KIEPredictor):
         out = self.doc_builder(
             pages,
             boxes_per_page,
+            objectness_scores_per_page,
             text_preds_per_page,
             origin_page_shapes,  # type: ignore[arg-type]
+            crop_orientations_per_page,
             orientations,
             languages_dict,
         )
doctr/models/modules/layers/pytorch.py
CHANGED
@@ -87,7 +87,7 @@ class FASTConvLayer(nn.Module):
         horizontal_outputs = (
             self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0
         )
-        id_out = self.rbr_identity(x) if self.rbr_identity is not None
+        id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0
 
         return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
 
@@ -106,7 +106,7 @@ class FASTConvLayer(nn.Module):
             id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
             self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
         kernel = self.id_tensor
-        std = (identity.running_var + identity.eps).sqrt()
+        std = (identity.running_var + identity.eps).sqrt()
         t = (identity.weight / std).reshape(-1, 1, 1, 1)
         return kernel * t, identity.bias - identity.running_mean * identity.weight / std
 
@@ -155,7 +155,6 @@ class FASTConvLayer(nn.Module):
         )
         self.fused_conv.weight.data = kernel
         self.fused_conv.bias.data = bias  # type: ignore[union-attr]
-        self.deploy = True
         for para in self.parameters():
             para.detach_()
         for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
doctr/models/modules/layers/tensorflow.py
CHANGED
@@ -97,7 +97,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
             if self.hor_bn is not None and self.hor_conv is not None
             else 0
         )
-        id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None
+        id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
 
         return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
 
@@ -110,14 +110,14 @@ class FASTConvLayer(layers.Layer, NestedObject):
             return 0, 0
         if not hasattr(self, "id_tensor"):
             input_dim = self.in_channels // self.groups
-            kernel_value = np.zeros((
+            kernel_value = np.zeros((1, 1, input_dim, self.in_channels), dtype=np.float32)
             for i in range(self.in_channels):
-                kernel_value[
+                kernel_value[0, 0, i % input_dim, i] = 1
             id_tensor = tf.constant(kernel_value, dtype=tf.float32)
             self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
         kernel = self.id_tensor
         std = tf.sqrt(identity.moving_variance + identity.epsilon)
-        t = tf.reshape(identity.gamma / std, (
+        t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
         return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
 
     def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
@@ -138,18 +138,16 @@ class FASTConvLayer(layers.Layer, NestedObject):
         else:
             kernel_1xn, bias_1xn = 0, 0
         kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
-        if not isinstance(kernel_id, int):
-            kernel_id = tf.transpose(kernel_id, (2, 3, 0, 1))
         kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
         bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
         return kernel_mxn, bias_mxn
 
     def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor:
         kernel_height, kernel_width = self.converted_ks
-        height, width = kernel.shape[2
+        height, width = kernel.shape[:2]
         pad_left_right = tf.maximum(0, (kernel_width - width) // 2)
         pad_top_down = tf.maximum(0, (kernel_height - height) // 2)
-        return tf.pad(kernel, [[
+        return tf.pad(kernel, [[pad_top_down, pad_top_down], [pad_left_right, pad_left_right], [0, 0], [0, 0]])
 
     def reparameterize_layer(self):
         kernel, bias = self._get_equivalent_kernel_bias()
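The rewritten `_identity_to_conv` above builds the identity kernel directly in TensorFlow's channels-last (H, W, in, out) layout. A quick NumPy-only sketch of that construction, with illustrative channel counts:

import numpy as np

# Mirrors the added lines: a 1x1 kernel in (H, W, in/groups, out) layout whose
# convolution reproduces the identity mapping. Channel counts are illustrative.
in_channels, groups = 4, 1
input_dim = in_channels // groups
kernel_value = np.zeros((1, 1, input_dim, in_channels), dtype=np.float32)
for i in range(in_channels):
    kernel_value[0, 0, i % input_dim, i] = 1
print(kernel_value[0, 0])  # 4x4 identity matrix when groups == 1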
doctr/models/modules/transformer/pytorch.py
CHANGED
@@ -51,8 +51,8 @@ def scaled_dot_product_attention(
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
     if mask is not None:
         # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
-        scores = scores.masked_fill(mask == 0, float("-inf"))
-    p_attn = torch.softmax(scores, dim=-1)
+        scores = scores.masked_fill(mask == 0, float("-inf"))
+    p_attn = torch.softmax(scores, dim=-1)
     return torch.matmul(p_attn, value), p_attn
 
 
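For reference, the masked scaled dot-product attention touched above can be exercised in isolation; a small sketch with illustrative shapes (the library's helper additionally returns the attention weights):

import math

import torch

# Small sketch of the masked scaled dot-product attention shown above.
# Shapes are illustrative: (batch, heads, sequence, head_dim).
q = k = v = torch.randn(1, 2, 3, 8)
mask = torch.tensor([[1, 1, 0]])  # last position masked out
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
scores = scores.masked_fill(mask == 0, float("-inf"))
p_attn = torch.softmax(scores, dim=-1)
output = torch.matmul(p_attn, v)
print(output.shape)  # torch.Size([1, 2, 3, 8])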
doctr/models/modules/transformer/tensorflow.py
CHANGED
@@ -13,8 +13,6 @@ from doctr.utils.repr import NestedObject
 
 __all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "PositionwiseFeedForward", "MultiHeadAttention"]
 
-tf.config.run_functions_eagerly(True)
-
 
 class PositionalEncoding(layers.Layer, NestedObject):
     """Compute positional encoding"""
doctr/models/modules/vision_transformer/pytorch.py
CHANGED
@@ -20,7 +20,7 @@ class PatchEmbedding(nn.Module):
         channels, height, width = input_shape
         self.patch_size = patch_size
         self.interpolate = True if patch_size[0] == patch_size[1] else False
-        self.grid_size = tuple(
+        self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
         self.num_patches = self.grid_size[0] * self.grid_size[1]
 
         self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
doctr/models/modules/vision_transformer/tensorflow.py
CHANGED
@@ -22,7 +22,7 @@ class PatchEmbedding(layers.Layer, NestedObject):
         height, width, _ = input_shape
         self.patch_size = patch_size
         self.interpolate = True if patch_size[0] == patch_size[1] else False
-        self.grid_size = tuple(
+        self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
         self.num_patches = self.grid_size[0] * self.grid_size[1]
 
         self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token")
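The grid-size computation rewritten in both PatchEmbedding variants is plain integer division over the spatial dimensions; a quick check of the arithmetic with illustrative sizes:

# Illustrative numbers only: a 32x128 input split into 4x8 patches.
height, width = 32, 128
patch_size = (4, 8)
grid_size = tuple(s // p for s, p in zip((height, width), patch_size))
num_patches = grid_size[0] * grid_size[1]
print(grid_size, num_patches)  # (8, 16) 128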
doctr/models/predictor/base.py
CHANGED
@@ -3,16 +3,16 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import numpy as np
 
 from doctr.models.builder import DocumentBuilder
-from doctr.utils.geometry import extract_crops, extract_rcrops
+from doctr.utils.geometry import extract_crops, extract_rcrops, remove_image_padding, rotate_image
 
-from .._utils import rectify_crops, rectify_loc_preds
-from ..classification import crop_orientation_predictor
-from ..classification.predictor import
+from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
+from ..classification import crop_orientation_predictor, page_orientation_predictor
+from ..classification.predictor import OrientationPredictor
 
 __all__ = ["_OCRPredictor"]
 
@@ -29,10 +29,13 @@ class _OCRPredictor:
             accordingly. Doing so will improve performances for documents with page-uniform rotations.
         preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
         symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically.
+        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
+            page. Doing so will slightly deteriorate the overall latency.
         **kwargs: keyword args of `DocumentBuilder`
     """
 
-    crop_orientation_predictor: Optional[
+    crop_orientation_predictor: Optional[OrientationPredictor]
+    page_orientation_predictor: Optional[OrientationPredictor]
 
     def __init__(
         self,
@@ -40,29 +43,93 @@ class _OCRPredictor:
         straighten_pages: bool = False,
         preserve_aspect_ratio: bool = True,
         symmetric_pad: bool = True,
+        detect_orientation: bool = False,
         **kwargs: Any,
     ) -> None:
         self.assume_straight_pages = assume_straight_pages
         self.straighten_pages = straighten_pages
-        self.
+        self._page_orientation_disabled = kwargs.pop("disable_page_orientation", False)
+        self._crop_orientation_disabled = kwargs.pop("disable_crop_orientation", False)
+        self.crop_orientation_predictor = (
+            None
+            if assume_straight_pages
+            else crop_orientation_predictor(pretrained=True, disabled=self._crop_orientation_disabled)
+        )
+        self.page_orientation_predictor = (
+            page_orientation_predictor(pretrained=True, disabled=self._page_orientation_disabled)
+            if detect_orientation or straighten_pages or not assume_straight_pages
+            else None
+        )
         self.doc_builder = DocumentBuilder(**kwargs)
         self.preserve_aspect_ratio = preserve_aspect_ratio
         self.symmetric_pad = symmetric_pad
         self.hooks: List[Callable] = []
 
+    def _general_page_orientations(
+        self,
+        pages: List[np.ndarray],
+    ) -> List[Tuple[int, float]]:
+        _, classes, probs = zip(self.page_orientation_predictor(pages))  # type: ignore[misc]
+        # Flatten to list of tuples with (value, confidence)
+        page_orientations = [
+            (orientation, prob)
+            for page_classes, page_probs in zip(classes, probs)
+            for orientation, prob in zip(page_classes, page_probs)
+        ]
+        return page_orientations
+
+    def _get_orientations(
+        self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
+    ) -> Tuple[List[Tuple[int, float]], List[int]]:
+        general_pages_orientations = self._general_page_orientations(pages)
+        origin_page_orientations = [
+            estimate_orientation(seq_map, general_orientation)
+            for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
+        ]
+        return general_pages_orientations, origin_page_orientations
+
+    def _straighten_pages(
+        self,
+        pages: List[np.ndarray],
+        seg_maps: List[np.ndarray],
+        general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
+        origin_pages_orientations: Optional[List[int]] = None,
+    ) -> List[np.ndarray]:
+        general_pages_orientations = (
+            general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
+        )
+        origin_pages_orientations = (
+            origin_pages_orientations
+            if origin_pages_orientations
+            else [
+                estimate_orientation(seq_map, general_orientation)
+                for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
+            ]
+        )
+        return [
+            # expand if height and width are not equal, then remove the padding
+            remove_image_padding(rotate_image(page, angle, expand=page.shape[0] != page.shape[1]))
+            for page, angle in zip(pages, origin_pages_orientations)
+        ]
+
     @staticmethod
     def _generate_crops(
         pages: List[np.ndarray],
         loc_preds: List[np.ndarray],
         channels_last: bool,
         assume_straight_pages: bool = False,
+        assume_horizontal: bool = False,
     ) -> List[List[np.ndarray]]:
-
-
-
-
-
-
+        if assume_straight_pages:
+            crops = [
+                extract_crops(page, _boxes[:, :4], channels_last=channels_last)
+                for page, _boxes in zip(pages, loc_preds)
+            ]
+        else:
+            crops = [
+                extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
+                for page, _boxes in zip(pages, loc_preds)
+            ]
         return crops
 
     @staticmethod
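The constructor hunk above pops two new switches, `disable_page_orientation` and `disable_crop_orientation`, from `**kwargs` and only instantiates the orientation models when they are actually needed. A hedged sketch of passing those switches through the zoo helper, assuming they are forwarded unchanged to `_OCRPredictor`:

from doctr.models import ocr_predictor

# Sketch only: keyword names are taken from the kwargs.pop(...) calls above and
# are assumed to be forwarded as-is by the zoo helper.
model = ocr_predictor(
    pretrained=True,
    assume_straight_pages=False,
    disable_crop_orientation=True,   # skip the per-crop orientation model
    disable_page_orientation=True,   # skip the per-page orientation model
)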
@@ -71,8 +138,9 @@ class _OCRPredictor:
         loc_preds: List[np.ndarray],
         channels_last: bool,
         assume_straight_pages: bool = False,
+        assume_horizontal: bool = False,
     ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
-        crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)
+        crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
 
         # Avoid sending zero-sized crops
         is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
@@ -88,68 +156,39 @@ class _OCRPredictor:
         self,
         crops: List[List[np.ndarray]],
         loc_preds: List[np.ndarray],
-    ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
+    ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]:
         # Work at a page level
-        orientations = [self.crop_orientation_predictor(page_crops) for page_crops in crops]  # type: ignore[misc]
+        orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops])  # type: ignore[misc]
         rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
         rect_loc_preds = [
             rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds
             for page_loc_preds, orientation in zip(loc_preds, orientations)
         ]
-
-
-
-
-
-
-
-        if self.preserve_aspect_ratio:
-            # Rectify loc_preds to remove padding
-            rectified_preds = []
-            for page, loc_pred in zip(pages, loc_preds):
-                h, w = page.shape[0], page.shape[1]
-                if h > w:
-                    # y unchanged, dilate x coord
-                    if self.symmetric_pad:
-                        if self.assume_straight_pages:
-                            loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
-                        else:
-                            loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
-                    else:
-                        if self.assume_straight_pages:
-                            loc_pred[:, [0, 2]] *= h / w
-                        else:
-                            loc_pred[:, :, 0] *= h / w
-                elif w > h:
-                    # x unchanged, dilate y coord
-                    if self.symmetric_pad:
-                        if self.assume_straight_pages:
-                            loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
-                        else:
-                            loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
-                    else:
-                        if self.assume_straight_pages:
-                            loc_pred[:, [1, 3]] *= w / h
-                        else:
-                            loc_pred[:, :, 1] *= w / h
-                rectified_preds.append(loc_pred)
-            return rectified_preds
-        return loc_preds
+        # Flatten to list of tuples with (value, confidence)
+        crop_orientations = [
+            (orientation, prob)
+            for page_classes, page_probs in zip(classes, probs)
+            for orientation, prob in zip(page_classes, page_probs)
+        ]
+        return rect_crops, rect_loc_preds, crop_orientations  # type: ignore[return-value]
 
     @staticmethod
     def _process_predictions(
         loc_preds: List[np.ndarray],
         word_preds: List[Tuple[str, float]],
-
+        crop_orientations: List[Dict[str, Any]],
+    ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]:
         text_preds = []
+        crop_orientation_preds = []
         if len(loc_preds) > 0:
-            # Text
+            # Text & crop orientation predictions at page level
             _idx = 0
             for page_boxes in loc_preds:
                 text_preds.append(word_preds[_idx : _idx + page_boxes.shape[0]])
+                crop_orientation_preds.append(crop_orientations[_idx : _idx + page_boxes.shape[0]])
                 _idx += page_boxes.shape[0]
 
-        return loc_preds, text_preds
+        return loc_preds, text_preds, crop_orientation_preds
 
     def add_hook(self, hook: Callable) -> None:
         """Add a hook to the predictor
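The section ends at `add_hook`. For reference, a short sketch of a location-predictions hook, based on the loop visible in the KIE hunks (`dict_loc_preds = hook(dict_loc_preds)`); the structure handed to the hook (class name to list of box arrays for the KIE predictor, a plain list for the OCR predictor) is inferred from those loops:

# Sketch of a custom hook: it receives the location predictions and must return
# an object of the same structure. This one is a no-op placeholder.
def passthrough_hook(loc_preds):
    # a real hook could filter, rescale, or re-order the predicted boxes here
    return loc_preds

model.add_hook(passthrough_hook)  # `model` is any doctr predictor exposing add_hook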