python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -15,7 +15,7 @@ from tensorflow.keras import layers
|
|
|
15
15
|
from tensorflow.keras.applications import ResNet50
|
|
16
16
|
|
|
17
17
|
from doctr.file_utils import CLASS_NAME
|
|
18
|
-
from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params
|
|
18
|
+
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
|
|
19
19
|
from doctr.utils.repr import NestedObject
|
|
20
20
|
|
|
21
21
|
from ...classification import mobilenet_v3_large
|
|
@@ -29,13 +29,13 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
29
29
|
"mean": (0.798, 0.785, 0.772),
|
|
30
30
|
"std": (0.264, 0.2749, 0.287),
|
|
31
31
|
"input_shape": (1024, 1024, 3),
|
|
32
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
32
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0",
|
|
33
33
|
},
|
|
34
34
|
"db_mobilenet_v3_large": {
|
|
35
35
|
"mean": (0.798, 0.785, 0.772),
|
|
36
36
|
"std": (0.264, 0.2749, 0.287),
|
|
37
37
|
"input_shape": (1024, 1024, 3),
|
|
38
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
38
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-da524564.zip&src=0",
|
|
39
39
|
},
|
|
40
40
|
}
|
|
41
41
|
|
|
@@ -45,6 +45,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
|
|
|
45
45
|
<https://arxiv.org/pdf/1612.03144.pdf>`_.
|
|
46
46
|
|
|
47
47
|
Args:
|
|
48
|
+
----
|
|
48
49
|
channels: number of channel to output
|
|
49
50
|
"""
|
|
50
51
|
|
|
@@ -66,14 +67,15 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
|
|
|
66
67
|
"""Module which performs a 3x3 convolution followed by up-sampling
|
|
67
68
|
|
|
68
69
|
Args:
|
|
70
|
+
----
|
|
69
71
|
channels: number of output channels
|
|
70
72
|
dilation_factor (int): dilation factor to scale the convolution output before concatenation
|
|
71
73
|
|
|
72
74
|
Returns:
|
|
75
|
+
-------
|
|
73
76
|
a keras.layers.Layer object, wrapping these operations in a sequential module
|
|
74
77
|
|
|
75
78
|
"""
|
|
76
|
-
|
|
77
79
|
_layers = conv_sequence(channels, "relu", True, kernel_size=3)
|
|
78
80
|
|
|
79
81
|
if dilation_factor > 1:
|
|
@@ -107,8 +109,11 @@ class DBNet(_DBNet, keras.Model, NestedObject):
|
|
|
107
109
|
<https://arxiv.org/pdf/1911.08947.pdf>`_.
|
|
108
110
|
|
|
109
111
|
Args:
|
|
112
|
+
----
|
|
110
113
|
feature extractor: the backbone serving as feature extractor
|
|
111
114
|
fpn_channels: number of channels each extracted feature maps is mapped to
|
|
115
|
+
bin_thresh: threshold for binarization
|
|
116
|
+
box_thresh: minimal objectness score to consider a box
|
|
112
117
|
assume_straight_pages: if True, fit straight bounding boxes only
|
|
113
118
|
exportable: onnx exportable returns only logits
|
|
114
119
|
cfg: the configuration dict of the model
|
|
@@ -122,6 +127,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
|
|
|
122
127
|
feature_extractor: IntermediateLayerGetter,
|
|
123
128
|
fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea
|
|
124
129
|
bin_thresh: float = 0.3,
|
|
130
|
+
box_thresh: float = 0.1,
|
|
125
131
|
assume_straight_pages: bool = True,
|
|
126
132
|
exportable: bool = False,
|
|
127
133
|
cfg: Optional[Dict[str, Any]] = None,
|
|
@@ -141,87 +147,96 @@ class DBNet(_DBNet, keras.Model, NestedObject):
|
|
|
141
147
|
_inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
|
|
142
148
|
output_shape = tuple(self.fpn(_inputs).shape)
|
|
143
149
|
|
|
144
|
-
self.probability_head = keras.Sequential(
|
|
145
|
-
[
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
150
|
+
self.probability_head = keras.Sequential([
|
|
151
|
+
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
|
|
152
|
+
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
|
|
153
|
+
layers.BatchNormalization(),
|
|
154
|
+
layers.Activation("relu"),
|
|
155
|
+
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
|
|
156
|
+
])
|
|
157
|
+
self.threshold_head = keras.Sequential([
|
|
158
|
+
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
|
|
159
|
+
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
|
|
160
|
+
layers.BatchNormalization(),
|
|
161
|
+
layers.Activation("relu"),
|
|
162
|
+
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
|
|
163
|
+
])
|
|
164
|
+
|
|
165
|
+
self.postprocessor = DBPostProcessor(
|
|
166
|
+
assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
|
|
161
167
|
)
|
|
162
168
|
|
|
163
|
-
self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
|
|
164
|
-
|
|
165
169
|
def compute_loss(
|
|
166
170
|
self,
|
|
167
171
|
out_map: tf.Tensor,
|
|
168
172
|
thresh_map: tf.Tensor,
|
|
169
173
|
target: List[Dict[str, np.ndarray]],
|
|
174
|
+
gamma: float = 2.0,
|
|
175
|
+
alpha: float = 0.5,
|
|
176
|
+
eps: float = 1e-8,
|
|
170
177
|
) -> tf.Tensor:
|
|
171
178
|
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
|
|
172
179
|
and a list of masks for each image. From there it computes the loss with the model output
|
|
173
180
|
|
|
174
181
|
Args:
|
|
182
|
+
----
|
|
175
183
|
out_map: output feature map of the model of shape (N, H, W, C)
|
|
176
184
|
thresh_map: threshold map of shape (N, H, W, C)
|
|
177
185
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
186
|
+
gamma: modulating factor in the focal loss formula
|
|
187
|
+
alpha: balancing factor in the focal loss formula
|
|
188
|
+
eps: epsilon factor in dice loss
|
|
178
189
|
|
|
179
190
|
Returns:
|
|
191
|
+
-------
|
|
180
192
|
A loss tensor
|
|
181
193
|
"""
|
|
194
|
+
if gamma < 0:
|
|
195
|
+
raise ValueError("Value of gamma should be greater than or equal to zero.")
|
|
182
196
|
|
|
183
197
|
prob_map = tf.math.sigmoid(out_map)
|
|
184
198
|
thresh_map = tf.math.sigmoid(thresh_map)
|
|
185
199
|
|
|
186
|
-
seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape, True)
|
|
200
|
+
seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True)
|
|
187
201
|
seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
|
|
188
202
|
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
|
|
203
|
+
seg_mask = tf.cast(seg_mask, tf.float32)
|
|
189
204
|
thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
|
|
190
205
|
thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)
|
|
191
206
|
|
|
192
|
-
#
|
|
193
|
-
|
|
194
|
-
bce_loss = tf.keras.losses.binary_crossentropy(
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
)
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
|
|
215
|
-
dice_loss = 1 - 2.0 * inter / union
|
|
207
|
+
# Focal loss
|
|
208
|
+
focal_scale = 10.0
|
|
209
|
+
bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
|
|
210
|
+
|
|
211
|
+
# Convert logits to prob, compute gamma factor
|
|
212
|
+
p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
|
|
213
|
+
alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
|
|
214
|
+
# Unreduced loss
|
|
215
|
+
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
|
|
216
|
+
# Class reduced
|
|
217
|
+
focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))
|
|
218
|
+
|
|
219
|
+
# Compute dice loss for each class or for approx binary_map
|
|
220
|
+
if len(self.class_names) > 1:
|
|
221
|
+
dice_map = tf.nn.softmax(out_map, axis=-1)
|
|
222
|
+
else:
|
|
223
|
+
# compute binary map instead
|
|
224
|
+
dice_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
|
|
225
|
+
# Class-reduced dice loss
|
|
226
|
+
inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
|
|
227
|
+
cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
|
|
228
|
+
dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))
|
|
216
229
|
|
|
217
230
|
# Compute l1 loss for thresh_map
|
|
218
|
-
l1_scale = 10.0
|
|
219
231
|
if tf.reduce_any(thresh_mask):
|
|
220
|
-
|
|
232
|
+
thresh_mask = tf.cast(thresh_mask, tf.float32)
|
|
233
|
+
l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / (
|
|
234
|
+
tf.reduce_sum(thresh_mask) + eps
|
|
235
|
+
)
|
|
221
236
|
else:
|
|
222
237
|
l1_loss = tf.constant(0.0)
|
|
223
238
|
|
|
224
|
-
return
|
|
239
|
+
return l1_loss + focal_scale * focal_loss + dice_loss
|
|
225
240
|
|
|
226
241
|
def call(
|
|
227
242
|
self,
|
|
@@ -241,7 +256,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
|
|
|
241
256
|
return out
|
|
242
257
|
|
|
243
258
|
if return_model_output or target is None or return_preds:
|
|
244
|
-
prob_map = tf.math.sigmoid(logits)
|
|
259
|
+
prob_map = _bf16_to_float32(tf.math.sigmoid(logits))
|
|
245
260
|
|
|
246
261
|
if return_model_output:
|
|
247
262
|
out["out_map"] = prob_map
|
|
@@ -342,12 +357,14 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
|
|
|
342
357
|
>>> out = model(input_tensor)
|
|
343
358
|
|
|
344
359
|
Args:
|
|
360
|
+
----
|
|
345
361
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
362
|
+
**kwargs: keyword arguments of the DBNet architecture
|
|
346
363
|
|
|
347
364
|
Returns:
|
|
365
|
+
-------
|
|
348
366
|
text detection architecture
|
|
349
367
|
"""
|
|
350
|
-
|
|
351
368
|
return _db_resnet(
|
|
352
369
|
"db_resnet50",
|
|
353
370
|
pretrained,
|
|
@@ -368,12 +385,14 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
|
|
|
368
385
|
>>> out = model(input_tensor)
|
|
369
386
|
|
|
370
387
|
Args:
|
|
388
|
+
----
|
|
371
389
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
390
|
+
**kwargs: keyword arguments of the DBNet architecture
|
|
372
391
|
|
|
373
392
|
Returns:
|
|
393
|
+
-------
|
|
374
394
|
text detection architecture
|
|
375
395
|
"""
|
|
376
|
-
|
|
377
396
|
return _db_mobilenet(
|
|
378
397
|
"db_mobilenet_v3_large",
|
|
379
398
|
pretrained,
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
|
+
|
|
8
|
+
from typing import Dict, List, Tuple, Union
|
|
9
|
+
|
|
10
|
+
import cv2
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pyclipper
|
|
13
|
+
from shapely.geometry import Polygon
|
|
14
|
+
|
|
15
|
+
from doctr.models.core import BaseModel
|
|
16
|
+
|
|
17
|
+
from ..core import DetectionPostProcessor
|
|
18
|
+
|
|
19
|
+
__all__ = ["_FAST", "FASTPostProcessor"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FASTPostProcessor(DetectionPostProcessor):
|
|
23
|
+
"""Implements a post processor for FAST model.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
----
|
|
27
|
+
bin_thresh: threshold used to binzarized p_map at inference time
|
|
28
|
+
box_thresh: minimal objectness score to consider a box
|
|
29
|
+
assume_straight_pages: whether the inputs were expected to have horizontal text elements
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
bin_thresh: float = 0.3,
|
|
35
|
+
box_thresh: float = 0.1,
|
|
36
|
+
assume_straight_pages: bool = True,
|
|
37
|
+
) -> None:
|
|
38
|
+
super().__init__(box_thresh, bin_thresh, assume_straight_pages)
|
|
39
|
+
self.unclip_ratio = 1.0
|
|
40
|
+
|
|
41
|
+
def polygon_to_box(
|
|
42
|
+
self,
|
|
43
|
+
points: np.ndarray,
|
|
44
|
+
) -> np.ndarray:
|
|
45
|
+
"""Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
----
|
|
49
|
+
points: The first parameter.
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
-------
|
|
53
|
+
a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
|
|
54
|
+
"""
|
|
55
|
+
if not self.assume_straight_pages:
|
|
56
|
+
# Compute the rectangle polygon enclosing the raw polygon
|
|
57
|
+
rect = cv2.minAreaRect(points)
|
|
58
|
+
points = cv2.boxPoints(rect)
|
|
59
|
+
# Add 1 pixel to correct cv2 approx
|
|
60
|
+
area = (rect[1][0] + 1) * (1 + rect[1][1])
|
|
61
|
+
length = 2 * (rect[1][0] + rect[1][1]) + 2
|
|
62
|
+
else:
|
|
63
|
+
poly = Polygon(points)
|
|
64
|
+
area = poly.area
|
|
65
|
+
length = poly.length
|
|
66
|
+
distance = area * self.unclip_ratio / length # compute distance to expand polygon
|
|
67
|
+
offset = pyclipper.PyclipperOffset()
|
|
68
|
+
offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
|
69
|
+
_points = offset.Execute(distance)
|
|
70
|
+
# Take biggest stack of points
|
|
71
|
+
idx = 0
|
|
72
|
+
if len(_points) > 1:
|
|
73
|
+
max_size = 0
|
|
74
|
+
for _idx, p in enumerate(_points):
|
|
75
|
+
if len(p) > max_size:
|
|
76
|
+
idx = _idx
|
|
77
|
+
max_size = len(p)
|
|
78
|
+
# We ensure that _points can be correctly casted to a ndarray
|
|
79
|
+
_points = [_points[idx]]
|
|
80
|
+
expanded_points: np.ndarray = np.asarray(_points) # expand polygon
|
|
81
|
+
if len(expanded_points) < 1:
|
|
82
|
+
return None # type: ignore[return-value]
|
|
83
|
+
return (
|
|
84
|
+
cv2.boundingRect(expanded_points) # type: ignore[return-value]
|
|
85
|
+
if self.assume_straight_pages
|
|
86
|
+
else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
def bitmap_to_boxes(
|
|
90
|
+
self,
|
|
91
|
+
pred: np.ndarray,
|
|
92
|
+
bitmap: np.ndarray,
|
|
93
|
+
) -> np.ndarray:
|
|
94
|
+
"""Compute boxes from a bitmap/pred_map: find connected components then filter boxes
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
----
|
|
98
|
+
pred: Pred map from differentiable linknet output
|
|
99
|
+
bitmap: Bitmap map computed from pred (binarized)
|
|
100
|
+
angle_tol: Comparison tolerance of the angle with the median angle across the page
|
|
101
|
+
ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
|
|
102
|
+
|
|
103
|
+
Returns:
|
|
104
|
+
-------
|
|
105
|
+
np tensor boxes for the bitmap, each box is a 6-element list
|
|
106
|
+
containing x, y, w, h, alpha, score for the box
|
|
107
|
+
"""
|
|
108
|
+
height, width = bitmap.shape[:2]
|
|
109
|
+
boxes: List[Union[np.ndarray, List[float]]] = []
|
|
110
|
+
# get contours from connected components on the bitmap
|
|
111
|
+
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
112
|
+
for contour in contours:
|
|
113
|
+
# Check whether smallest enclosing bounding box is not too small
|
|
114
|
+
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
|
|
115
|
+
continue
|
|
116
|
+
# Compute objectness
|
|
117
|
+
if self.assume_straight_pages:
|
|
118
|
+
x, y, w, h = cv2.boundingRect(contour)
|
|
119
|
+
points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]])
|
|
120
|
+
score = self.box_score(pred, points, assume_straight_pages=True)
|
|
121
|
+
else:
|
|
122
|
+
score = self.box_score(pred, contour, assume_straight_pages=False)
|
|
123
|
+
|
|
124
|
+
if score < self.box_thresh: # remove polygons with a weak objectness
|
|
125
|
+
continue
|
|
126
|
+
|
|
127
|
+
if self.assume_straight_pages:
|
|
128
|
+
_box = self.polygon_to_box(points)
|
|
129
|
+
else:
|
|
130
|
+
_box = self.polygon_to_box(np.squeeze(contour))
|
|
131
|
+
|
|
132
|
+
if self.assume_straight_pages:
|
|
133
|
+
# compute relative polygon to get rid of img shape
|
|
134
|
+
x, y, w, h = _box
|
|
135
|
+
xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
|
|
136
|
+
boxes.append([xmin, ymin, xmax, ymax, score])
|
|
137
|
+
else:
|
|
138
|
+
# compute relative box to get rid of img shape
|
|
139
|
+
_box[:, 0] /= width
|
|
140
|
+
_box[:, 1] /= height
|
|
141
|
+
boxes.append(_box)
|
|
142
|
+
|
|
143
|
+
if not self.assume_straight_pages:
|
|
144
|
+
return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
|
|
145
|
+
else:
|
|
146
|
+
return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class _FAST(BaseModel):
|
|
150
|
+
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
|
|
151
|
+
<https://arxiv.org/pdf/2111.02394.pdf>`_.
|
|
152
|
+
"""
|
|
153
|
+
|
|
154
|
+
min_size_box: int = 3
|
|
155
|
+
assume_straight_pages: bool = True
|
|
156
|
+
shrink_ratio = 0.1
|
|
157
|
+
|
|
158
|
+
def build_target(
|
|
159
|
+
self,
|
|
160
|
+
target: List[Dict[str, np.ndarray]],
|
|
161
|
+
output_shape: Tuple[int, int, int],
|
|
162
|
+
channels_last: bool = True,
|
|
163
|
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
164
|
+
"""Build the target, and it's mask to be used from loss computation.
|
|
165
|
+
|
|
166
|
+
Args:
|
|
167
|
+
----
|
|
168
|
+
target: target coming from dataset
|
|
169
|
+
output_shape: shape of the output of the model without batch_size
|
|
170
|
+
channels_last: whether channels are last or not
|
|
171
|
+
|
|
172
|
+
Returns:
|
|
173
|
+
-------
|
|
174
|
+
the new formatted target, mask and shrunken text kernel
|
|
175
|
+
"""
|
|
176
|
+
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
|
|
177
|
+
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
|
|
178
|
+
if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
|
|
179
|
+
raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")
|
|
180
|
+
|
|
181
|
+
h: int
|
|
182
|
+
w: int
|
|
183
|
+
if channels_last:
|
|
184
|
+
h, w, num_classes = output_shape
|
|
185
|
+
else:
|
|
186
|
+
num_classes, h, w = output_shape
|
|
187
|
+
target_shape = (len(target), num_classes, h, w)
|
|
188
|
+
|
|
189
|
+
seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
|
|
190
|
+
seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
|
|
191
|
+
shrunken_kernel: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
|
|
192
|
+
|
|
193
|
+
for idx, tgt in enumerate(target):
|
|
194
|
+
for class_idx, _tgt in enumerate(tgt.values()):
|
|
195
|
+
# Draw each polygon on gt
|
|
196
|
+
if _tgt.shape[0] == 0:
|
|
197
|
+
# Empty image, full masked
|
|
198
|
+
seg_mask[idx, class_idx] = False
|
|
199
|
+
|
|
200
|
+
# Absolute bounding boxes
|
|
201
|
+
abs_boxes = _tgt.copy()
|
|
202
|
+
|
|
203
|
+
if abs_boxes.ndim == 3:
|
|
204
|
+
abs_boxes[:, :, 0] *= w
|
|
205
|
+
abs_boxes[:, :, 1] *= h
|
|
206
|
+
polys = abs_boxes
|
|
207
|
+
boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
|
|
208
|
+
abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
|
|
209
|
+
else:
|
|
210
|
+
abs_boxes[:, [0, 2]] *= w
|
|
211
|
+
abs_boxes[:, [1, 3]] *= h
|
|
212
|
+
abs_boxes = abs_boxes.round().astype(np.int32)
|
|
213
|
+
polys = np.stack(
|
|
214
|
+
[
|
|
215
|
+
abs_boxes[:, [0, 1]],
|
|
216
|
+
abs_boxes[:, [0, 3]],
|
|
217
|
+
abs_boxes[:, [2, 3]],
|
|
218
|
+
abs_boxes[:, [2, 1]],
|
|
219
|
+
],
|
|
220
|
+
axis=1,
|
|
221
|
+
)
|
|
222
|
+
boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
|
|
223
|
+
|
|
224
|
+
for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
|
|
225
|
+
# Mask boxes that are too small
|
|
226
|
+
if box_size < self.min_size_box:
|
|
227
|
+
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
|
|
228
|
+
continue
|
|
229
|
+
|
|
230
|
+
# Negative shrink for gt, as described in paper
|
|
231
|
+
polygon = Polygon(poly)
|
|
232
|
+
distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
|
|
233
|
+
subject = [tuple(coor) for coor in poly]
|
|
234
|
+
padding = pyclipper.PyclipperOffset()
|
|
235
|
+
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
|
236
|
+
shrunken = padding.Execute(-distance)
|
|
237
|
+
|
|
238
|
+
# Draw polygon on gt if it is valid
|
|
239
|
+
if len(shrunken) == 0:
|
|
240
|
+
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
|
|
241
|
+
continue
|
|
242
|
+
shrunken = np.array(shrunken[0]).reshape(-1, 2)
|
|
243
|
+
if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
|
|
244
|
+
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
|
|
245
|
+
continue
|
|
246
|
+
cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
|
|
247
|
+
# draw the original polygon on the segmentation target
|
|
248
|
+
cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0) # type: ignore[call-overload]
|
|
249
|
+
|
|
250
|
+
# Don't forget to switch back to channel last if Tensorflow is used
|
|
251
|
+
if channels_last:
|
|
252
|
+
seg_target = seg_target.transpose((0, 2, 3, 1))
|
|
253
|
+
seg_mask = seg_mask.transpose((0, 2, 3, 1))
|
|
254
|
+
shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))
|
|
255
|
+
|
|
256
|
+
return seg_target, seg_mask, shrunken_kernel
|