python-doctr 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +1 -1
- doctr/contrib/__init__.py +0 -0
- doctr/contrib/artefacts.py +131 -0
- doctr/contrib/base.py +105 -0
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/generator/base.py +6 -5
- doctr/datasets/imgur5k.py +1 -1
- doctr/datasets/loader.py +1 -6
- doctr/datasets/utils.py +2 -1
- doctr/datasets/vocabs.py +9 -2
- doctr/file_utils.py +26 -12
- doctr/io/elements.py +40 -6
- doctr/io/html.py +2 -2
- doctr/io/image/pytorch.py +6 -8
- doctr/io/image/tensorflow.py +1 -1
- doctr/io/pdf.py +5 -2
- doctr/io/reader.py +6 -0
- doctr/models/__init__.py +0 -1
- doctr/models/_utils.py +57 -20
- doctr/models/builder.py +71 -13
- doctr/models/classification/mobilenet/pytorch.py +45 -9
- doctr/models/classification/mobilenet/tensorflow.py +38 -7
- doctr/models/classification/predictor/pytorch.py +18 -11
- doctr/models/classification/predictor/tensorflow.py +16 -10
- doctr/models/classification/textnet/pytorch.py +3 -3
- doctr/models/classification/textnet/tensorflow.py +3 -3
- doctr/models/classification/zoo.py +39 -15
- doctr/models/detection/_utils/__init__.py +1 -0
- doctr/models/detection/_utils/base.py +66 -0
- doctr/models/detection/differentiable_binarization/base.py +4 -3
- doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
- doctr/models/detection/fast/base.py +6 -5
- doctr/models/detection/fast/pytorch.py +4 -4
- doctr/models/detection/fast/tensorflow.py +4 -4
- doctr/models/detection/linknet/base.py +4 -3
- doctr/models/detection/predictor/pytorch.py +15 -1
- doctr/models/detection/predictor/tensorflow.py +15 -1
- doctr/models/detection/zoo.py +7 -2
- doctr/models/factory/hub.py +3 -12
- doctr/models/kie_predictor/base.py +9 -3
- doctr/models/kie_predictor/pytorch.py +41 -20
- doctr/models/kie_predictor/tensorflow.py +36 -16
- doctr/models/modules/layers/pytorch.py +2 -3
- doctr/models/modules/layers/tensorflow.py +6 -8
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/predictor/base.py +77 -50
- doctr/models/predictor/pytorch.py +31 -20
- doctr/models/predictor/tensorflow.py +27 -17
- doctr/models/preprocessor/pytorch.py +4 -4
- doctr/models/preprocessor/tensorflow.py +3 -2
- doctr/models/recognition/master/pytorch.py +2 -2
- doctr/models/recognition/parseq/pytorch.py +4 -3
- doctr/models/recognition/parseq/tensorflow.py +4 -3
- doctr/models/recognition/sar/pytorch.py +7 -6
- doctr/models/recognition/sar/tensorflow.py +3 -9
- doctr/models/recognition/vitstr/pytorch.py +1 -1
- doctr/models/recognition/zoo.py +1 -1
- doctr/models/zoo.py +2 -2
- doctr/py.typed +0 -0
- doctr/transforms/functional/base.py +1 -1
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/base.py +37 -15
- doctr/transforms/modules/pytorch.py +66 -8
- doctr/transforms/modules/tensorflow.py +63 -7
- doctr/utils/fonts.py +7 -5
- doctr/utils/geometry.py +35 -12
- doctr/utils/metrics.py +33 -174
- doctr/utils/reconstitution.py +126 -0
- doctr/utils/visualization.py +5 -118
- doctr/version.py +1 -1
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/METADATA +84 -80
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/RECORD +76 -76
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/WHEEL +1 -1
- doctr/models/artefacts/__init__.py +0 -2
- doctr/models/artefacts/barcode.py +0 -74
- doctr/models/artefacts/face.py +0 -63
- doctr/models/obj_detection/__init__.py +0 -1
- doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/zip-safe +0 -0
|
@@ -10,10 +10,10 @@ import torch
|
|
|
10
10
|
from torch import nn
|
|
11
11
|
|
|
12
12
|
from doctr.io.elements import Document
|
|
13
|
-
from doctr.models._utils import
|
|
13
|
+
from doctr.models._utils import get_language, invert_data_structure
|
|
14
14
|
from doctr.models.detection.predictor import DetectionPredictor
|
|
15
15
|
from doctr.models.recognition.predictor import RecognitionPredictor
|
|
16
|
-
from doctr.utils.geometry import
|
|
16
|
+
from doctr.utils.geometry import detach_scores
|
|
17
17
|
|
|
18
18
|
from .base import _KIEPredictor
|
|
19
19
|
|
|
@@ -55,7 +55,13 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
55
55
|
self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
|
|
56
56
|
self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
|
|
57
57
|
_KIEPredictor.__init__(
|
|
58
|
-
self,
|
|
58
|
+
self,
|
|
59
|
+
assume_straight_pages,
|
|
60
|
+
straighten_pages,
|
|
61
|
+
preserve_aspect_ratio,
|
|
62
|
+
symmetric_pad,
|
|
63
|
+
detect_orientation,
|
|
64
|
+
**kwargs,
|
|
59
65
|
)
|
|
60
66
|
self.detect_orientation = detect_orientation
|
|
61
67
|
self.detect_language = detect_language
|
|
@@ -83,29 +89,31 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
83
89
|
for out_map in out_maps
|
|
84
90
|
]
|
|
85
91
|
if self.detect_orientation:
|
|
86
|
-
|
|
92
|
+
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) # type: ignore[arg-type]
|
|
87
93
|
orientations = [
|
|
88
|
-
{"value": orientation_page, "confidence": None} for orientation_page in
|
|
94
|
+
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
|
|
89
95
|
]
|
|
90
96
|
else:
|
|
91
97
|
orientations = None
|
|
98
|
+
general_pages_orientations = None
|
|
99
|
+
origin_pages_orientations = None
|
|
92
100
|
if self.straighten_pages:
|
|
93
|
-
|
|
94
|
-
origin_page_orientations
|
|
95
|
-
if self.detect_orientation
|
|
96
|
-
else [estimate_orientation(seq_map) for seq_map in seg_maps]
|
|
97
|
-
)
|
|
98
|
-
pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
|
|
101
|
+
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
|
|
99
102
|
# Forward again to get predictions on straight pages
|
|
100
103
|
loc_preds = self.det_predictor(pages, **kwargs)
|
|
101
104
|
|
|
102
105
|
dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]
|
|
106
|
+
|
|
107
|
+
# Detach objectness scores from loc_preds
|
|
108
|
+
objectness_scores = {}
|
|
109
|
+
for class_name, det_preds in dict_loc_preds.items():
|
|
110
|
+
_loc_preds, _scores = detach_scores(det_preds)
|
|
111
|
+
dict_loc_preds[class_name] = _loc_preds
|
|
112
|
+
objectness_scores[class_name] = _scores
|
|
113
|
+
|
|
103
114
|
# Check whether crop mode should be switched to channels first
|
|
104
115
|
channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
|
|
105
116
|
|
|
106
|
-
# Rectify crops if aspect ratio
|
|
107
|
-
dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
|
|
108
|
-
|
|
109
117
|
# Apply hooks to loc_preds if any
|
|
110
118
|
for hook in self.hooks:
|
|
111
119
|
dict_loc_preds = hook(dict_loc_preds)
|
|
@@ -114,32 +122,43 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
114
122
|
crops = {}
|
|
115
123
|
for class_name in dict_loc_preds.keys():
|
|
116
124
|
crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
|
|
117
|
-
pages,
|
|
125
|
+
pages, # type: ignore[arg-type]
|
|
118
126
|
dict_loc_preds[class_name],
|
|
119
127
|
channels_last=channels_last,
|
|
120
128
|
assume_straight_pages=self.assume_straight_pages,
|
|
121
129
|
)
|
|
122
130
|
# Rectify crop orientation
|
|
131
|
+
crop_orientations: Any = {}
|
|
123
132
|
if not self.assume_straight_pages:
|
|
124
133
|
for class_name in dict_loc_preds.keys():
|
|
125
|
-
crops[class_name], dict_loc_preds[class_name] = self._rectify_crops(
|
|
134
|
+
crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
|
|
126
135
|
crops[class_name], dict_loc_preds[class_name]
|
|
127
136
|
)
|
|
137
|
+
crop_orientations[class_name] = [
|
|
138
|
+
{"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
|
|
139
|
+
]
|
|
140
|
+
|
|
128
141
|
# Identify character sequences
|
|
129
142
|
word_preds = {
|
|
130
143
|
k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
|
|
131
144
|
for k, crop_value in crops.items()
|
|
132
145
|
}
|
|
146
|
+
if not crop_orientations:
|
|
147
|
+
crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
|
|
133
148
|
|
|
134
149
|
boxes: Dict = {}
|
|
135
150
|
text_preds: Dict = {}
|
|
151
|
+
word_crop_orientations: Dict = {}
|
|
136
152
|
for class_name in dict_loc_preds.keys():
|
|
137
|
-
boxes[class_name], text_preds[class_name] = self._process_predictions(
|
|
138
|
-
dict_loc_preds[class_name], word_preds[class_name]
|
|
153
|
+
boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
|
|
154
|
+
dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
|
|
139
155
|
)
|
|
140
156
|
|
|
141
157
|
boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
|
|
158
|
+
objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
|
|
142
159
|
text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
|
|
160
|
+
crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
|
|
161
|
+
|
|
143
162
|
if self.detect_language:
|
|
144
163
|
languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
|
|
145
164
|
languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
|
|
@@ -147,10 +166,12 @@ class KIEPredictor(nn.Module, _KIEPredictor):
|
|
|
147
166
|
languages_dict = None
|
|
148
167
|
|
|
149
168
|
out = self.doc_builder(
|
|
150
|
-
pages,
|
|
169
|
+
pages, # type: ignore[arg-type]
|
|
151
170
|
boxes_per_page,
|
|
171
|
+
objectness_scores_per_page,
|
|
152
172
|
text_preds_per_page,
|
|
153
|
-
origin_page_shapes,
|
|
173
|
+
origin_page_shapes, # type: ignore[arg-type]
|
|
174
|
+
crop_orientations_per_page,
|
|
154
175
|
orientations,
|
|
155
176
|
languages_dict,
|
|
156
177
|
)
|
|
@@ -9,10 +9,10 @@ import numpy as np
|
|
|
9
9
|
import tensorflow as tf
|
|
10
10
|
|
|
11
11
|
from doctr.io.elements import Document
|
|
12
|
-
from doctr.models._utils import
|
|
12
|
+
from doctr.models._utils import get_language, invert_data_structure
|
|
13
13
|
from doctr.models.detection.predictor import DetectionPredictor
|
|
14
14
|
from doctr.models.recognition.predictor import RecognitionPredictor
|
|
15
|
-
from doctr.utils.geometry import
|
|
15
|
+
from doctr.utils.geometry import detach_scores
|
|
16
16
|
from doctr.utils.repr import NestedObject
|
|
17
17
|
|
|
18
18
|
from .base import _KIEPredictor
|
|
@@ -56,7 +56,13 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
56
56
|
self.det_predictor = det_predictor
|
|
57
57
|
self.reco_predictor = reco_predictor
|
|
58
58
|
_KIEPredictor.__init__(
|
|
59
|
-
self,
|
|
59
|
+
self,
|
|
60
|
+
assume_straight_pages,
|
|
61
|
+
straighten_pages,
|
|
62
|
+
preserve_aspect_ratio,
|
|
63
|
+
symmetric_pad,
|
|
64
|
+
detect_orientation,
|
|
65
|
+
**kwargs,
|
|
60
66
|
)
|
|
61
67
|
self.detect_orientation = detect_orientation
|
|
62
68
|
self.detect_language = detect_language
|
|
@@ -83,25 +89,27 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
83
89
|
for out_map in out_maps
|
|
84
90
|
]
|
|
85
91
|
if self.detect_orientation:
|
|
86
|
-
|
|
92
|
+
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
87
93
|
orientations = [
|
|
88
|
-
{"value": orientation_page, "confidence": None} for orientation_page in
|
|
94
|
+
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
|
|
89
95
|
]
|
|
90
96
|
else:
|
|
91
97
|
orientations = None
|
|
98
|
+
general_pages_orientations = None
|
|
99
|
+
origin_pages_orientations = None
|
|
92
100
|
if self.straighten_pages:
|
|
93
|
-
|
|
94
|
-
origin_page_orientations
|
|
95
|
-
if self.detect_orientation
|
|
96
|
-
else [estimate_orientation(seq_map) for seq_map in seg_maps]
|
|
97
|
-
)
|
|
98
|
-
pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
|
|
101
|
+
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
99
102
|
# Forward again to get predictions on straight pages
|
|
100
103
|
loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
|
|
101
104
|
|
|
102
105
|
dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
|
|
103
|
-
|
|
104
|
-
|
|
106
|
+
|
|
107
|
+
# Detach objectness scores from loc_preds
|
|
108
|
+
objectness_scores = {}
|
|
109
|
+
for class_name, det_preds in dict_loc_preds.items():
|
|
110
|
+
_loc_preds, _scores = detach_scores(det_preds)
|
|
111
|
+
dict_loc_preds[class_name] = _loc_preds
|
|
112
|
+
objectness_scores[class_name] = _scores
|
|
105
113
|
|
|
106
114
|
# Apply hooks to loc_preds if any
|
|
107
115
|
for hook in self.hooks:
|
|
@@ -113,28 +121,38 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
113
121
|
crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
|
|
114
122
|
pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
|
|
115
123
|
)
|
|
124
|
+
|
|
116
125
|
# Rectify crop orientation
|
|
126
|
+
crop_orientations: Any = {}
|
|
117
127
|
if not self.assume_straight_pages:
|
|
118
128
|
for class_name in dict_loc_preds.keys():
|
|
119
|
-
crops[class_name], dict_loc_preds[class_name] = self._rectify_crops(
|
|
129
|
+
crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
|
|
120
130
|
crops[class_name], dict_loc_preds[class_name]
|
|
121
131
|
)
|
|
132
|
+
crop_orientations[class_name] = [
|
|
133
|
+
{"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
|
|
134
|
+
]
|
|
122
135
|
|
|
123
136
|
# Identify character sequences
|
|
124
137
|
word_preds = {
|
|
125
138
|
k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
|
|
126
139
|
for k, crop_value in crops.items()
|
|
127
140
|
}
|
|
141
|
+
if not crop_orientations:
|
|
142
|
+
crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
|
|
128
143
|
|
|
129
144
|
boxes: Dict = {}
|
|
130
145
|
text_preds: Dict = {}
|
|
146
|
+
word_crop_orientations: Dict = {}
|
|
131
147
|
for class_name in dict_loc_preds.keys():
|
|
132
|
-
boxes[class_name], text_preds[class_name] = self._process_predictions(
|
|
133
|
-
dict_loc_preds[class_name], word_preds[class_name]
|
|
148
|
+
boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
|
|
149
|
+
dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
|
|
134
150
|
)
|
|
135
151
|
|
|
136
152
|
boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
|
|
153
|
+
objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
|
|
137
154
|
text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
|
|
155
|
+
crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
|
|
138
156
|
|
|
139
157
|
if self.detect_language:
|
|
140
158
|
languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
|
|
@@ -145,8 +163,10 @@ class KIEPredictor(NestedObject, _KIEPredictor):
|
|
|
145
163
|
out = self.doc_builder(
|
|
146
164
|
pages,
|
|
147
165
|
boxes_per_page,
|
|
166
|
+
objectness_scores_per_page,
|
|
148
167
|
text_preds_per_page,
|
|
149
168
|
origin_page_shapes, # type: ignore[arg-type]
|
|
169
|
+
crop_orientations_per_page,
|
|
150
170
|
orientations,
|
|
151
171
|
languages_dict,
|
|
152
172
|
)
|
|
@@ -87,7 +87,7 @@ class FASTConvLayer(nn.Module):
|
|
|
87
87
|
horizontal_outputs = (
|
|
88
88
|
self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0
|
|
89
89
|
)
|
|
90
|
-
id_out = self.rbr_identity(x) if self.rbr_identity is not None
|
|
90
|
+
id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0
|
|
91
91
|
|
|
92
92
|
return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
|
|
93
93
|
|
|
@@ -106,7 +106,7 @@ class FASTConvLayer(nn.Module):
|
|
|
106
106
|
id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
|
|
107
107
|
self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
|
|
108
108
|
kernel = self.id_tensor
|
|
109
|
-
std = (identity.running_var + identity.eps).sqrt()
|
|
109
|
+
std = (identity.running_var + identity.eps).sqrt()
|
|
110
110
|
t = (identity.weight / std).reshape(-1, 1, 1, 1)
|
|
111
111
|
return kernel * t, identity.bias - identity.running_mean * identity.weight / std
|
|
112
112
|
|
|
@@ -155,7 +155,6 @@ class FASTConvLayer(nn.Module):
|
|
|
155
155
|
)
|
|
156
156
|
self.fused_conv.weight.data = kernel
|
|
157
157
|
self.fused_conv.bias.data = bias # type: ignore[union-attr]
|
|
158
|
-
self.deploy = True
|
|
159
158
|
for para in self.parameters():
|
|
160
159
|
para.detach_()
|
|
161
160
|
for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
|
|
@@ -97,7 +97,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
97
97
|
if self.hor_bn is not None and self.hor_conv is not None
|
|
98
98
|
else 0
|
|
99
99
|
)
|
|
100
|
-
id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None
|
|
100
|
+
id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
|
|
101
101
|
|
|
102
102
|
return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
|
|
103
103
|
|
|
@@ -110,14 +110,14 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
110
110
|
return 0, 0
|
|
111
111
|
if not hasattr(self, "id_tensor"):
|
|
112
112
|
input_dim = self.in_channels // self.groups
|
|
113
|
-
kernel_value = np.zeros((
|
|
113
|
+
kernel_value = np.zeros((1, 1, input_dim, self.in_channels), dtype=np.float32)
|
|
114
114
|
for i in range(self.in_channels):
|
|
115
|
-
kernel_value[
|
|
115
|
+
kernel_value[0, 0, i % input_dim, i] = 1
|
|
116
116
|
id_tensor = tf.constant(kernel_value, dtype=tf.float32)
|
|
117
117
|
self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
|
|
118
118
|
kernel = self.id_tensor
|
|
119
119
|
std = tf.sqrt(identity.moving_variance + identity.epsilon)
|
|
120
|
-
t = tf.reshape(identity.gamma / std, (
|
|
120
|
+
t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
|
|
121
121
|
return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
|
|
122
122
|
|
|
123
123
|
def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
|
|
@@ -138,18 +138,16 @@ class FASTConvLayer(layers.Layer, NestedObject):
|
|
|
138
138
|
else:
|
|
139
139
|
kernel_1xn, bias_1xn = 0, 0
|
|
140
140
|
kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
|
|
141
|
-
if not isinstance(kernel_id, int):
|
|
142
|
-
kernel_id = tf.transpose(kernel_id, (2, 3, 0, 1))
|
|
143
141
|
kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
|
|
144
142
|
bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
|
|
145
143
|
return kernel_mxn, bias_mxn
|
|
146
144
|
|
|
147
145
|
def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor:
|
|
148
146
|
kernel_height, kernel_width = self.converted_ks
|
|
149
|
-
height, width = kernel.shape[2
|
|
147
|
+
height, width = kernel.shape[:2]
|
|
150
148
|
pad_left_right = tf.maximum(0, (kernel_width - width) // 2)
|
|
151
149
|
pad_top_down = tf.maximum(0, (kernel_height - height) // 2)
|
|
152
|
-
return tf.pad(kernel, [[
|
|
150
|
+
return tf.pad(kernel, [[pad_top_down, pad_top_down], [pad_left_right, pad_left_right], [0, 0], [0, 0]])
|
|
153
151
|
|
|
154
152
|
def reparameterize_layer(self):
|
|
155
153
|
kernel, bias = self._get_equivalent_kernel_bias()
|
|
@@ -51,8 +51,8 @@ def scaled_dot_product_attention(
|
|
|
51
51
|
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
|
|
52
52
|
if mask is not None:
|
|
53
53
|
# NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
|
|
54
|
-
scores = scores.masked_fill(mask == 0, float("-inf"))
|
|
55
|
-
p_attn = torch.softmax(scores, dim=-1)
|
|
54
|
+
scores = scores.masked_fill(mask == 0, float("-inf"))
|
|
55
|
+
p_attn = torch.softmax(scores, dim=-1)
|
|
56
56
|
return torch.matmul(p_attn, value), p_attn
|
|
57
57
|
|
|
58
58
|
|
doctr/models/predictor/base.py
CHANGED
|
@@ -3,16 +3,16 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any, Callable, List, Optional, Tuple
|
|
6
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
10
|
from doctr.models.builder import DocumentBuilder
|
|
11
|
-
from doctr.utils.geometry import extract_crops, extract_rcrops
|
|
11
|
+
from doctr.utils.geometry import extract_crops, extract_rcrops, rotate_image
|
|
12
12
|
|
|
13
|
-
from .._utils import rectify_crops, rectify_loc_preds
|
|
14
|
-
from ..classification import crop_orientation_predictor
|
|
15
|
-
from ..classification.predictor import
|
|
13
|
+
from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
|
|
14
|
+
from ..classification import crop_orientation_predictor, page_orientation_predictor
|
|
15
|
+
from ..classification.predictor import OrientationPredictor
|
|
16
16
|
|
|
17
17
|
__all__ = ["_OCRPredictor"]
|
|
18
18
|
|
|
@@ -29,10 +29,13 @@ class _OCRPredictor:
|
|
|
29
29
|
accordingly. Doing so will improve performances for documents with page-uniform rotations.
|
|
30
30
|
preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
|
|
31
31
|
symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically.
|
|
32
|
+
detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
|
|
33
|
+
page. Doing so will slightly deteriorate the overall latency.
|
|
32
34
|
**kwargs: keyword args of `DocumentBuilder`
|
|
33
35
|
"""
|
|
34
36
|
|
|
35
|
-
crop_orientation_predictor: Optional[
|
|
37
|
+
crop_orientation_predictor: Optional[OrientationPredictor]
|
|
38
|
+
page_orientation_predictor: Optional[OrientationPredictor]
|
|
36
39
|
|
|
37
40
|
def __init__(
|
|
38
41
|
self,
|
|
@@ -40,16 +43,69 @@ class _OCRPredictor:
|
|
|
40
43
|
straighten_pages: bool = False,
|
|
41
44
|
preserve_aspect_ratio: bool = True,
|
|
42
45
|
symmetric_pad: bool = True,
|
|
46
|
+
detect_orientation: bool = False,
|
|
43
47
|
**kwargs: Any,
|
|
44
48
|
) -> None:
|
|
45
49
|
self.assume_straight_pages = assume_straight_pages
|
|
46
50
|
self.straighten_pages = straighten_pages
|
|
47
51
|
self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True)
|
|
52
|
+
self.page_orientation_predictor = (
|
|
53
|
+
page_orientation_predictor(pretrained=True)
|
|
54
|
+
if detect_orientation or straighten_pages or not assume_straight_pages
|
|
55
|
+
else None
|
|
56
|
+
)
|
|
48
57
|
self.doc_builder = DocumentBuilder(**kwargs)
|
|
49
58
|
self.preserve_aspect_ratio = preserve_aspect_ratio
|
|
50
59
|
self.symmetric_pad = symmetric_pad
|
|
51
60
|
self.hooks: List[Callable] = []
|
|
52
61
|
|
|
62
|
+
def _general_page_orientations(
|
|
63
|
+
self,
|
|
64
|
+
pages: List[np.ndarray],
|
|
65
|
+
) -> List[Tuple[int, float]]:
|
|
66
|
+
_, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
|
|
67
|
+
# Flatten to list of tuples with (value, confidence)
|
|
68
|
+
page_orientations = [
|
|
69
|
+
(orientation, prob)
|
|
70
|
+
for page_classes, page_probs in zip(classes, probs)
|
|
71
|
+
for orientation, prob in zip(page_classes, page_probs)
|
|
72
|
+
]
|
|
73
|
+
return page_orientations
|
|
74
|
+
|
|
75
|
+
def _get_orientations(
|
|
76
|
+
self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
|
|
77
|
+
) -> Tuple[List[Tuple[int, float]], List[int]]:
|
|
78
|
+
general_pages_orientations = self._general_page_orientations(pages)
|
|
79
|
+
origin_page_orientations = [
|
|
80
|
+
estimate_orientation(seq_map, general_orientation)
|
|
81
|
+
for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
|
|
82
|
+
]
|
|
83
|
+
return general_pages_orientations, origin_page_orientations
|
|
84
|
+
|
|
85
|
+
def _straighten_pages(
|
|
86
|
+
self,
|
|
87
|
+
pages: List[np.ndarray],
|
|
88
|
+
seg_maps: List[np.ndarray],
|
|
89
|
+
general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
|
|
90
|
+
origin_pages_orientations: Optional[List[int]] = None,
|
|
91
|
+
) -> List[np.ndarray]:
|
|
92
|
+
general_pages_orientations = (
|
|
93
|
+
general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
|
|
94
|
+
)
|
|
95
|
+
origin_pages_orientations = (
|
|
96
|
+
origin_pages_orientations
|
|
97
|
+
if origin_pages_orientations
|
|
98
|
+
else [
|
|
99
|
+
estimate_orientation(seq_map, general_orientation)
|
|
100
|
+
for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
|
|
101
|
+
]
|
|
102
|
+
)
|
|
103
|
+
return [
|
|
104
|
+
# We exapnd if the page is wider than tall and the angle is 90 or -90
|
|
105
|
+
rotate_image(page, angle, expand=page.shape[1] > page.shape[0] and abs(angle) == 90)
|
|
106
|
+
for page, angle in zip(pages, origin_pages_orientations)
|
|
107
|
+
]
|
|
108
|
+
|
|
53
109
|
@staticmethod
|
|
54
110
|
def _generate_crops(
|
|
55
111
|
pages: List[np.ndarray],
|
|
@@ -88,68 +144,39 @@ class _OCRPredictor:
|
|
|
88
144
|
self,
|
|
89
145
|
crops: List[List[np.ndarray]],
|
|
90
146
|
loc_preds: List[np.ndarray],
|
|
91
|
-
) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
|
|
147
|
+
) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]:
|
|
92
148
|
# Work at a page level
|
|
93
|
-
orientations = [self.crop_orientation_predictor(page_crops) for page_crops in crops] # type: ignore[misc]
|
|
149
|
+
orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc]
|
|
94
150
|
rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
|
|
95
151
|
rect_loc_preds = [
|
|
96
152
|
rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds
|
|
97
153
|
for page_loc_preds, orientation in zip(loc_preds, orientations)
|
|
98
154
|
]
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
if self.preserve_aspect_ratio:
|
|
107
|
-
# Rectify loc_preds to remove padding
|
|
108
|
-
rectified_preds = []
|
|
109
|
-
for page, loc_pred in zip(pages, loc_preds):
|
|
110
|
-
h, w = page.shape[0], page.shape[1]
|
|
111
|
-
if h > w:
|
|
112
|
-
# y unchanged, dilate x coord
|
|
113
|
-
if self.symmetric_pad:
|
|
114
|
-
if self.assume_straight_pages:
|
|
115
|
-
loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
|
|
116
|
-
else:
|
|
117
|
-
loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
|
|
118
|
-
else:
|
|
119
|
-
if self.assume_straight_pages:
|
|
120
|
-
loc_pred[:, [0, 2]] *= h / w
|
|
121
|
-
else:
|
|
122
|
-
loc_pred[:, :, 0] *= h / w
|
|
123
|
-
elif w > h:
|
|
124
|
-
# x unchanged, dilate y coord
|
|
125
|
-
if self.symmetric_pad:
|
|
126
|
-
if self.assume_straight_pages:
|
|
127
|
-
loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
|
|
128
|
-
else:
|
|
129
|
-
loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
|
|
130
|
-
else:
|
|
131
|
-
if self.assume_straight_pages:
|
|
132
|
-
loc_pred[:, [1, 3]] *= w / h
|
|
133
|
-
else:
|
|
134
|
-
loc_pred[:, :, 1] *= w / h
|
|
135
|
-
rectified_preds.append(loc_pred)
|
|
136
|
-
return rectified_preds
|
|
137
|
-
return loc_preds
|
|
155
|
+
# Flatten to list of tuples with (value, confidence)
|
|
156
|
+
crop_orientations = [
|
|
157
|
+
(orientation, prob)
|
|
158
|
+
for page_classes, page_probs in zip(classes, probs)
|
|
159
|
+
for orientation, prob in zip(page_classes, page_probs)
|
|
160
|
+
]
|
|
161
|
+
return rect_crops, rect_loc_preds, crop_orientations # type: ignore[return-value]
|
|
138
162
|
|
|
139
163
|
@staticmethod
|
|
140
164
|
def _process_predictions(
|
|
141
165
|
loc_preds: List[np.ndarray],
|
|
142
166
|
word_preds: List[Tuple[str, float]],
|
|
143
|
-
|
|
167
|
+
crop_orientations: List[Dict[str, Any]],
|
|
168
|
+
) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]:
|
|
144
169
|
text_preds = []
|
|
170
|
+
crop_orientation_preds = []
|
|
145
171
|
if len(loc_preds) > 0:
|
|
146
|
-
# Text
|
|
172
|
+
# Text & crop orientation predictions at page level
|
|
147
173
|
_idx = 0
|
|
148
174
|
for page_boxes in loc_preds:
|
|
149
175
|
text_preds.append(word_preds[_idx : _idx + page_boxes.shape[0]])
|
|
176
|
+
crop_orientation_preds.append(crop_orientations[_idx : _idx + page_boxes.shape[0]])
|
|
150
177
|
_idx += page_boxes.shape[0]
|
|
151
178
|
|
|
152
|
-
return loc_preds, text_preds
|
|
179
|
+
return loc_preds, text_preds, crop_orientation_preds
|
|
153
180
|
|
|
154
181
|
def add_hook(self, hook: Callable) -> None:
|
|
155
182
|
"""Add a hook to the predictor
|
|
@@ -10,10 +10,10 @@ import torch
|
|
|
10
10
|
from torch import nn
|
|
11
11
|
|
|
12
12
|
from doctr.io.elements import Document
|
|
13
|
-
from doctr.models._utils import
|
|
13
|
+
from doctr.models._utils import get_language
|
|
14
14
|
from doctr.models.detection.predictor import DetectionPredictor
|
|
15
15
|
from doctr.models.recognition.predictor import RecognitionPredictor
|
|
16
|
-
from doctr.utils.geometry import
|
|
16
|
+
from doctr.utils.geometry import detach_scores
|
|
17
17
|
|
|
18
18
|
from .base import _OCRPredictor
|
|
19
19
|
|
|
@@ -55,7 +55,13 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
55
55
|
self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
|
|
56
56
|
self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
|
|
57
57
|
_OCRPredictor.__init__(
|
|
58
|
-
self,
|
|
58
|
+
self,
|
|
59
|
+
assume_straight_pages,
|
|
60
|
+
straighten_pages,
|
|
61
|
+
preserve_aspect_ratio,
|
|
62
|
+
symmetric_pad,
|
|
63
|
+
detect_orientation,
|
|
64
|
+
**kwargs,
|
|
59
65
|
)
|
|
60
66
|
self.detect_orientation = detect_orientation
|
|
61
67
|
self.detect_language = detect_language
|
|
@@ -81,19 +87,16 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
81
87
|
for out_map in out_maps
|
|
82
88
|
]
|
|
83
89
|
if self.detect_orientation:
|
|
84
|
-
|
|
90
|
+
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) # type: ignore[arg-type]
|
|
85
91
|
orientations = [
|
|
86
|
-
{"value": orientation_page, "confidence": None} for orientation_page in
|
|
92
|
+
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
|
|
87
93
|
]
|
|
88
94
|
else:
|
|
89
95
|
orientations = None
|
|
96
|
+
general_pages_orientations = None
|
|
97
|
+
origin_pages_orientations = None
|
|
90
98
|
if self.straighten_pages:
|
|
91
|
-
|
|
92
|
-
origin_page_orientations
|
|
93
|
-
if self.detect_orientation
|
|
94
|
-
else [estimate_orientation(seq_map) for seq_map in seg_maps]
|
|
95
|
-
)
|
|
96
|
-
pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
|
|
99
|
+
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
|
|
97
100
|
# Forward again to get predictions on straight pages
|
|
98
101
|
loc_preds = self.det_predictor(pages, **kwargs)
|
|
99
102
|
|
|
@@ -102,30 +105,36 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
102
105
|
), "Detection Model in ocr_predictor should output only one class"
|
|
103
106
|
|
|
104
107
|
loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
|
|
108
|
+
# Detach objectness scores from loc_preds
|
|
109
|
+
loc_preds, objectness_scores = detach_scores(loc_preds)
|
|
105
110
|
# Check whether crop mode should be switched to channels first
|
|
106
111
|
channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
|
|
107
112
|
|
|
108
|
-
# Rectify crops if aspect ratio
|
|
109
|
-
loc_preds = self._remove_padding(pages, loc_preds)
|
|
110
|
-
|
|
111
113
|
# Apply hooks to loc_preds if any
|
|
112
114
|
for hook in self.hooks:
|
|
113
115
|
loc_preds = hook(loc_preds)
|
|
114
116
|
|
|
115
117
|
# Crop images
|
|
116
118
|
crops, loc_preds = self._prepare_crops(
|
|
117
|
-
pages,
|
|
119
|
+
pages, # type: ignore[arg-type]
|
|
118
120
|
loc_preds,
|
|
119
121
|
channels_last=channels_last,
|
|
120
122
|
assume_straight_pages=self.assume_straight_pages,
|
|
121
123
|
)
|
|
122
|
-
# Rectify crop orientation
|
|
124
|
+
# Rectify crop orientation and get crop orientation predictions
|
|
125
|
+
crop_orientations: Any = []
|
|
123
126
|
if not self.assume_straight_pages:
|
|
124
|
-
crops, loc_preds = self._rectify_crops(crops, loc_preds)
|
|
127
|
+
crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
|
|
128
|
+
crop_orientations = [
|
|
129
|
+
{"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
|
|
130
|
+
]
|
|
131
|
+
|
|
125
132
|
# Identify character sequences
|
|
126
133
|
word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
|
|
134
|
+
if not crop_orientations:
|
|
135
|
+
crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]
|
|
127
136
|
|
|
128
|
-
boxes, text_preds = self._process_predictions(loc_preds, word_preds)
|
|
137
|
+
boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)
|
|
129
138
|
|
|
130
139
|
if self.detect_language:
|
|
131
140
|
languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]
|
|
@@ -134,10 +143,12 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
134
143
|
languages_dict = None
|
|
135
144
|
|
|
136
145
|
out = self.doc_builder(
|
|
137
|
-
pages,
|
|
146
|
+
pages, # type: ignore[arg-type]
|
|
138
147
|
boxes,
|
|
148
|
+
objectness_scores,
|
|
139
149
|
text_preds,
|
|
140
|
-
origin_page_shapes,
|
|
150
|
+
origin_page_shapes, # type: ignore[arg-type]
|
|
151
|
+
crop_orientations,
|
|
141
152
|
orientations,
|
|
142
153
|
languages_dict,
|
|
143
154
|
)
|