python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,428 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
|
+
|
|
8
|
+
from copy import deepcopy
|
|
9
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import tensorflow as tf
|
|
13
|
+
from tensorflow import keras
|
|
14
|
+
from tensorflow.keras import Sequential, layers
|
|
15
|
+
|
|
16
|
+
from doctr.file_utils import CLASS_NAME
|
|
17
|
+
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params
|
|
18
|
+
from doctr.utils.repr import NestedObject
|
|
19
|
+
|
|
20
|
+
from ...classification import textnet_base, textnet_small, textnet_tiny
|
|
21
|
+
from ...modules.layers import FASTConvLayer
|
|
22
|
+
from .base import _FAST, FASTPostProcessor
|
|
23
|
+
|
|
24
|
+
# Public API of this module
__all__ = [
    "FAST",
    "fast_tiny",
    "fast_small",
    "fast_base",
    "reparameterize",
]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Per-architecture configuration. All FAST variants currently share the same input shape (H, W, C),
# normalization statistics and have no downloadable checkpoint (url is None).
# Each architecture gets its own (independent) dict so that per-arch edits do not leak into the others.
default_cfgs: Dict[str, Dict[str, Any]] = {
    arch_name: {
        "input_shape": (1024, 1024, 3),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": None,
    }
    for arch_name in ("fast_tiny", "fast_small", "fast_base")
}
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class FastNeck(layers.Layer, NestedObject):
    """FAST neck: reduces each of the four backbone feature maps with a 3x3 FASTConvLayer, resizes them to the
    resolution of the first one and concatenates them along the channel axis.

    Args:
    ----
        in_channels: number of channels of the first feature map (the k-th map is expected to carry
            in_channels * 2**k channels)
        out_channels: number of channels produced by each reduction layer
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 128,
    ) -> None:
        super().__init__()
        # One reduction layer per backbone stage, with input widths in_channels x (1, 2, 4, 8)
        self.reduction = [
            FASTConvLayer(in_channels * (2**stage_idx), out_channels, kernel_size=3) for stage_idx in range(4)
        ]

    def _upsample(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
        # Bilinearly resize `x` to the spatial size (axes 1 and 2) of `y`
        target_hw = y.shape[1:3]
        return tf.image.resize(x, size=target_hw, method="bilinear")

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        # Exactly four feature maps are expected (one per backbone stage)
        c1, c2, c3, c4 = x
        reduced = [conv(feat, **kwargs) for conv, feat in zip(self.reduction, (c1, c2, c3, c4))]
        anchor = reduced[0]
        resized = [anchor] + [self._upsample(feat, anchor) for feat in reduced[1:]]
        return tf.concat(resized, axis=-1)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class FastHead(Sequential):
    """FAST prediction head: a 3x3 FASTConvLayer, dropout, then a bias-free 1x1 convolution that outputs one
    channel per class.

    Args:
    ----
        in_channels: number of input channels
        num_classes: number of output classes
        out_channels: number of channels of the intermediate FASTConvLayer
        dropout: dropout probability
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        out_channels: int = 128,
        dropout: float = 0.1,
    ) -> None:
        super().__init__(
            [
                FASTConvLayer(in_channels, out_channels, kernel_size=3),
                layers.Dropout(dropout),
                layers.Conv2D(num_classes, kernel_size=1, use_bias=False),
            ]
        )
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class FAST(_FAST, keras.Model, NestedObject):
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_.

    Args:
    ----
        feature_extractor: the backbone serving as feature extractor
        bin_thresh: threshold for binarization
        box_thresh: minimal objectness score to consider a box
        dropout_prob: dropout probability
        pooling_size: size of the pooling layer
        assume_straight_pages: if True, fit straight bounding boxes only
        exportable: onnx exportable returns only logits
        cfg: the configuration dict of the model (an empty dict when not provided)
        class_names: list of class names (``[CLASS_NAME]`` when not provided)
    """

    _children_names: List[str] = ["feat_extractor", "neck", "head", "postprocessor"]

    def __init__(
        self,
        feature_extractor: IntermediateLayerGetter,
        bin_thresh: float = 0.3,
        box_thresh: float = 0.1,
        dropout_prob: float = 0.1,
        pooling_size: int = 4,  # differs from the paper: performs better on close, text-rich images
        assume_straight_pages: bool = True,
        exportable: bool = False,
        cfg: Optional[Dict[str, Any]] = None,
        class_names: Optional[List[str]] = None,
    ) -> None:
        super().__init__()
        # Fresh containers per instance (avoid shared mutable default arguments)
        self.class_names = [CLASS_NAME] if class_names is None else class_names
        num_classes: int = len(self.class_names)
        self.cfg = {} if cfg is None else cfg

        self.feat_extractor = feature_extractor
        self.exportable = exportable
        self.assume_straight_pages = assume_straight_pages

        # Number of channels (last axis) of each backbone output, used to size the neck & head
        feat_out_channels = [out_shape[-1] for out_shape in self.feat_extractor.output_shape]
        # Initialize neck & head
        self.neck = FastNeck(feat_out_channels[0], feat_out_channels[1])
        self.head = FastHead(feat_out_channels[-1], num_classes, feat_out_channels[1], dropout_prob)

        # NOTE: The post processing from the paper does not work well for text-rich images
        # so we use a modified version from DBNet
        self.postprocessor = FASTPostProcessor(
            assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
        )

        # Pooling layer as erosion reversal as described in the paper
        self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same")

    def compute_loss(
        self,
        out_map: tf.Tensor,
        target: List[Dict[str, np.ndarray]],
        eps: float = 1e-6,
    ) -> tf.Tensor:
        """Compute the FAST loss: a Dice loss on the text kernels plus a Dice loss (with online hard example
        mining) on the text regions, the latter being scaled by 0.5.

        Args:
        ----
            out_map: output feature map (logits) of the model of shape (N, H, W, num_classes)
            target: list of dictionary where each dict has a `boxes` and a `flags` entry
            eps: epsilon factor in dice loss

        Returns:
        -------
            A loss tensor
        """
        targets = self.build_target(target, out_map.shape[1:], True)

        seg_target = tf.convert_to_tensor(targets[0], dtype=out_map.dtype)
        seg_mask = tf.convert_to_tensor(targets[1], dtype=out_map.dtype)
        shrunken_kernel = tf.convert_to_tensor(targets[2], dtype=out_map.dtype)

        def ohem(score: tf.Tensor, gt: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
            """Online hard example mining for one sample: keep every valid positive pixel and the
            `3 * num_positives` negative pixels with the highest scores (hardest negatives)."""
            pos_num = tf.reduce_sum(tf.cast(gt > 0.5, dtype=tf.int32)) - tf.reduce_sum(
                tf.cast((gt > 0.5) & (mask <= 0.5), dtype=tf.int32)
            )
            neg_num = tf.reduce_sum(tf.cast(gt <= 0.5, dtype=tf.int32))
            neg_num = tf.minimum(pos_num * 3, neg_num)

            if neg_num == 0 or pos_num == 0:
                return mask

            # top_k returns the `neg_num` LARGEST negative scores in descending order:
            # the last one is the score threshold of the hardest negatives
            neg_score_sorted, _ = tf.nn.top_k(tf.boolean_mask(score, gt <= 0.5), k=neg_num)
            threshold = neg_score_sorted[-1]

            selected_mask = tf.math.logical_and((score >= threshold) | (gt > 0.5), (mask > 0.5))
            # Same dtype as the early-return branch so that the per-sample masks can be stacked
            return tf.cast(selected_mask, dtype=mask.dtype)

        if len(self.class_names) > 1:
            kernels = tf.nn.softmax(out_map, axis=-1)
            prob_map = tf.nn.softmax(self.pooling(out_map), axis=-1)
        else:
            kernels = tf.sigmoid(out_map)
            prob_map = tf.sigmoid(self.pooling(out_map))

        # Text region loss (on the max-pooled prediction, with OHEM): Dice loss scaled by 0.5, as in the paper
        selected_masks = tf.stack(
            [ohem(score, gt, mask) for score, gt, mask in zip(prob_map, seg_target, seg_mask)], axis=0
        )
        inter = tf.reduce_sum(selected_masks * prob_map * seg_target, axis=(0, 1, 2))
        cardinality = tf.reduce_sum(selected_masks * (prob_map + seg_target), axis=(0, 1, 2))
        text_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps))) * 0.5

        # Text kernel loss: Dice loss restricted to valid text pixels, as in the paper
        selected_masks = seg_target * seg_mask
        inter = tf.reduce_sum(selected_masks * kernels * shrunken_kernel, axis=(0, 1, 2))
        cardinality = tf.reduce_sum(selected_masks * (kernels + shrunken_kernel), axis=(0, 1, 2))
        kernel_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps)))

        return text_loss + kernel_loss

    def call(
        self,
        x: tf.Tensor,
        target: Optional[List[Dict[str, np.ndarray]]] = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the detector on a batch of channels-last images.

        Args:
        ----
            x: input batch of shape (N, H, W, C)
            target: optional list of targets; when provided the loss is computed
            return_model_output: whether to add the probability map under `out_map`
            return_preds: whether to add post-processed predictions under `preds` even when `target` is given
            **kwargs: forwarded to the sub-layers (e.g. `training`)

        Returns:
        -------
            a dict with only `logits` in exportable mode, otherwise any of `out_map`, `preds` and `loss`
        """
        feat_maps = self.feat_extractor(x, **kwargs)
        # Pass through the Neck & Head & Upsample
        feat_concat = self.neck(feat_maps, **kwargs)
        logits: tf.Tensor = self.head(feat_concat, **kwargs)
        # Bring the logits back to the input spatial resolution (same bilinear resize as the neck)
        logits = tf.image.resize(logits, size=x.shape[1:3], method="bilinear")

        out: Dict[str, tf.Tensor] = {}
        if self.exportable:
            out["logits"] = logits
            return out

        if return_model_output or target is None or return_preds:
            prob_map = _bf16_to_float32(tf.math.sigmoid(self.pooling(logits, **kwargs)))

        if return_model_output:
            out["out_map"] = prob_map

        if target is None or return_preds:
            # Post-process boxes (keep only text predictions)
            out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]

        if target is not None:
            loss = self.compute_loss(logits, target)
            out["loss"] = loss

        return out
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _fuse_conv_bn(conv: layers.Conv2D, bn: layers.BatchNormalization) -> None:
    """Fold the inference-time affine transform of `bn` into the kernel and bias of `conv`, then turn `bn` into an
    identity mapping. Both layers are modified in place; `conv` must have a bias.

    NOTE(review): assumes channels-last tensors, i.e. `bn` normalizes the last axis, which is also the
    output-channel axis of the Conv2D kernel (kh, kw, in_channels, out_channels) — same assumption as before.
    """
    # bn(y) = factor * y + (beta - moving_mean * factor), with factor = gamma / sqrt(moving_variance + epsilon)
    factor = 1.0 / tf.sqrt(bn.moving_variance + bn.epsilon)
    if bn.gamma is not None:  # gamma is None when the batchnorm was built with scale=False
        factor = bn.gamma * factor
    fused_bias = (conv.bias - bn.moving_mean) * factor
    if bn.beta is not None:  # beta is None when the batchnorm was built with center=False
        fused_bias = fused_bias + bn.beta

    conv.kernel.assign(conv.kernel * tf.reshape(factor, [1, 1, 1, -1]))
    conv.bias.assign(fused_bias)

    # `model.layers` is a freshly built list in Keras, so the batchnorm cannot be swapped out by assigning into it:
    # reset its statistics and affine parameters so that it computes the identity at inference time instead.
    bn.moving_mean.assign(tf.zeros_like(bn.moving_mean))
    bn.moving_variance.assign(tf.ones_like(bn.moving_variance) - bn.epsilon)
    if bn.gamma is not None:
        bn.gamma.assign(tf.ones_like(bn.gamma))
    if bn.beta is not None:
        bn.beta.assign(tf.zeros_like(bn.beta))


def reparameterize(model: Union[FAST, layers.Layer]) -> FAST:
    """Fuse batchnorm and conv layers and reparameterize the model

    The model is modified in place: every `FASTConvLayer` (including those of a `FastNeck`) is reparameterized,
    container layers are processed recursively, and a BatchNormalization placed directly after a Conv2D that has a
    bias is folded into that convolution.

    Args:
    ----
        model: the FAST model (or one of its sub-layers) to reparameterize

    Returns:
    -------
        the reparameterized model
    """
    last_conv: Optional[layers.Conv2D] = None

    for layer in model.layers:
        if isinstance(layer, layers.Conv2D):
            # Candidate for fusion with a batchnorm that comes right after it
            last_conv = layer
            continue

        if isinstance(layer, layers.BatchNormalization):
            # fuse batchnorm only if it directly follows a conv layer that has a bias to absorb the shift;
            # otherwise both layers are left untouched so the model output is unchanged
            if last_conv is not None and last_conv.use_bias:
                _fuse_conv_bn(last_conv, layer)
        elif isinstance(layer, FASTConvLayer):
            layer.reparameterize_layer()
        elif isinstance(layer, FastNeck):
            for reduction in layer.reduction:
                reduction.reparameterize_layer()
        elif isinstance(layer, FastHead) or hasattr(layer, "layers"):
            # Containers (head, backbone, stages, ...) are processed recursively
            reparameterize(layer)

        # Any layer other than a conv breaks the conv -> batchnorm adjacency
        last_conv = None

    return model
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _fast(
    arch: str,
    pretrained: bool,
    backbone_fn,
    feat_layers: List[str],
    pretrained_backbone: bool = True,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> FAST:
    """Instantiate a FAST model for the given architecture.

    Args:
    ----
        arch: key of `default_cfgs` selecting the configuration
        pretrained: whether to load the pretrained parameters of the whole model from `default_cfgs[arch]["url"]`
        backbone_fn: backbone constructor, called with `input_shape`, `include_top=False` and `pretrained`
        feat_layers: names of the backbone layers whose outputs are fed to the neck
        pretrained_backbone: whether to load pretrained backbone parameters (ignored when `pretrained` is True)
        input_shape: (H, W, C) input shape overriding the default configuration
        **kwargs: keyword arguments forwarded to `FAST` (`class_names` is sorted when provided)

    Returns:
    -------
        the FAST model
    """
    # The full-model checkpoint already contains the backbone parameters
    pretrained_backbone = pretrained_backbone and not pretrained

    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["input_shape"] = input_shape or _cfg["input_shape"]
    if not kwargs.get("class_names", None):
        kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME])
    else:
        kwargs["class_names"] = sorted(kwargs["class_names"])

    # Feature extractor
    feat_extractor = IntermediateLayerGetter(
        backbone_fn(
            input_shape=_cfg["input_shape"],
            include_top=False,
            pretrained=pretrained_backbone,
        ),
        feat_layers,
    )

    # Build the model
    model = FAST(feat_extractor, cfg=_cfg, **kwargs)
    # Run a dummy forward pass so that every variable exists before the weights are restored,
    # and so that the layers are built for reparameterization
    _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False)

    # Load pretrained parameters
    if pretrained:
        load_pretrained_params(model, _cfg["url"])

    return model
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import fast_tiny
    >>> model = fast_tiny(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture

    Returns:
    -------
        text detection architecture
    """
    return _fast(
        "fast_tiny",
        pretrained,
        textnet_tiny,
        ["stage_0", "stage_1", "stage_2", "stage_3"],
        **kwargs,
    )
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import fast_small
    >>> model = fast_small(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture

    Returns:
    -------
        text detection architecture
    """
    return _fast(
        "fast_small",
        pretrained,
        textnet_small,
        ["stage_0", "stage_1", "stage_2", "stage_3"],
        **kwargs,
    )
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import fast_base
    >>> model = fast_base(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture

    Returns:
    -------
        text detection architecture
    """
    return _fast(
        "fast_base",
        pretrained,
        textnet_base,
        ["stage_0", "stage_1", "stage_2", "stage_3"],
        **kwargs,
    )
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -23,6 +23,7 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
23
23
|
"""Implements a post processor for LinkNet model.
|
|
24
24
|
|
|
25
25
|
Args:
|
|
26
|
+
----
|
|
26
27
|
bin_thresh: threshold used to binzarized p_map at inference time
|
|
27
28
|
box_thresh: minimal objectness score to consider a box
|
|
28
29
|
assume_straight_pages: whether the inputs were expected to have horizontal text elements
|
|
@@ -35,7 +36,7 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
35
36
|
assume_straight_pages: bool = True,
|
|
36
37
|
) -> None:
|
|
37
38
|
super().__init__(box_thresh, bin_thresh, assume_straight_pages)
|
|
38
|
-
self.unclip_ratio = 1.
|
|
39
|
+
self.unclip_ratio = 1.5
|
|
39
40
|
|
|
40
41
|
def polygon_to_box(
|
|
41
42
|
self,
|
|
@@ -44,9 +45,11 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
44
45
|
"""Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
|
|
45
46
|
|
|
46
47
|
Args:
|
|
48
|
+
----
|
|
47
49
|
points: The first parameter.
|
|
48
50
|
|
|
49
51
|
Returns:
|
|
52
|
+
-------
|
|
50
53
|
a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
|
|
51
54
|
"""
|
|
52
55
|
if not self.assume_straight_pages:
|
|
@@ -78,7 +81,7 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
78
81
|
if len(expanded_points) < 1:
|
|
79
82
|
return None # type: ignore[return-value]
|
|
80
83
|
return (
|
|
81
|
-
cv2.boundingRect(expanded_points)
|
|
84
|
+
cv2.boundingRect(expanded_points) # type: ignore[return-value]
|
|
82
85
|
if self.assume_straight_pages
|
|
83
86
|
else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
|
|
84
87
|
)
|
|
@@ -91,12 +94,14 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
91
94
|
"""Compute boxes from a bitmap/pred_map: find connected components then filter boxes
|
|
92
95
|
|
|
93
96
|
Args:
|
|
97
|
+
----
|
|
94
98
|
pred: Pred map from differentiable linknet output
|
|
95
99
|
bitmap: Bitmap map computed from pred (binarized)
|
|
96
100
|
angle_tol: Comparison tolerance of the angle with the median angle across the page
|
|
97
101
|
ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
|
|
98
102
|
|
|
99
103
|
Returns:
|
|
104
|
+
-------
|
|
100
105
|
np tensor boxes for the bitmap, each box is a 6-element list
|
|
101
106
|
containing x, y, w, h, alpha, score for the box
|
|
102
107
|
"""
|
|
@@ -146,6 +151,7 @@ class _LinkNet(BaseModel):
|
|
|
146
151
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
147
152
|
|
|
148
153
|
Args:
|
|
154
|
+
----
|
|
149
155
|
out_chan: number of channels for the output
|
|
150
156
|
"""
|
|
151
157
|
|
|
@@ -162,14 +168,15 @@ class _LinkNet(BaseModel):
|
|
|
162
168
|
"""Build the target, and it's mask to be used from loss computation.
|
|
163
169
|
|
|
164
170
|
Args:
|
|
171
|
+
----
|
|
165
172
|
target: target coming from dataset
|
|
166
173
|
output_shape: shape of the output of the model without batch_size
|
|
167
174
|
channels_last: whether channels are last or not
|
|
168
175
|
|
|
169
176
|
Returns:
|
|
177
|
+
-------
|
|
170
178
|
the new formatted target and the mask
|
|
171
179
|
"""
|
|
172
|
-
|
|
173
180
|
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
|
|
174
181
|
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
|
|
175
182
|
if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
|
|
@@ -239,7 +246,7 @@ class _LinkNet(BaseModel):
|
|
|
239
246
|
if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
|
|
240
247
|
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
|
|
241
248
|
continue
|
|
242
|
-
cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1)
|
|
249
|
+
cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
|
|
243
250
|
|
|
244
251
|
# Don't forget to switch back to channel last if Tensorflow is used
|
|
245
252
|
if channels_last:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -14,7 +14,7 @@ from torchvision.models._utils import IntermediateLayerGetter
|
|
|
14
14
|
from doctr.file_utils import CLASS_NAME
|
|
15
15
|
from doctr.models.classification import resnet18, resnet34, resnet50
|
|
16
16
|
|
|
17
|
-
from ...utils import load_pretrained_params
|
|
17
|
+
from ...utils import _bf16_to_float32, load_pretrained_params
|
|
18
18
|
from .base import LinkNetPostProcessor, _LinkNet
|
|
19
19
|
|
|
20
20
|
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
|
|
@@ -25,19 +25,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
25
25
|
"input_shape": (3, 1024, 1024),
|
|
26
26
|
"mean": (0.798, 0.785, 0.772),
|
|
27
27
|
"std": (0.264, 0.2749, 0.287),
|
|
28
|
-
"url":
|
|
28
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-e47a14dc.pt&src=0",
|
|
29
29
|
},
|
|
30
30
|
"linknet_resnet34": {
|
|
31
31
|
"input_shape": (3, 1024, 1024),
|
|
32
32
|
"mean": (0.798, 0.785, 0.772),
|
|
33
33
|
"std": (0.264, 0.2749, 0.287),
|
|
34
|
-
"url":
|
|
34
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-9ca2df3e.pt&src=0",
|
|
35
35
|
},
|
|
36
36
|
"linknet_resnet50": {
|
|
37
37
|
"input_shape": (3, 1024, 1024),
|
|
38
38
|
"mean": (0.798, 0.785, 0.772),
|
|
39
39
|
"std": (0.264, 0.2749, 0.287),
|
|
40
|
-
"url":
|
|
40
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-6cf565c1.pt&src=0",
|
|
41
41
|
},
|
|
42
42
|
}
|
|
43
43
|
|
|
@@ -61,7 +61,6 @@ class LinkNetFPN(nn.Module):
|
|
|
61
61
|
@staticmethod
|
|
62
62
|
def decoder_block(in_chan: int, out_chan: int, stride: int) -> nn.Sequential:
|
|
63
63
|
"""Creates a LinkNet decoder block"""
|
|
64
|
-
|
|
65
64
|
mid_chan = in_chan // 4
|
|
66
65
|
return nn.Sequential(
|
|
67
66
|
nn.Conv2d(in_chan, mid_chan, kernel_size=1, bias=False),
|
|
@@ -90,7 +89,10 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
90
89
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
91
90
|
|
|
92
91
|
Args:
|
|
92
|
+
----
|
|
93
93
|
feature extractor: the backbone serving as feature extractor
|
|
94
|
+
bin_thresh: threshold for binarization of the output feature map
|
|
95
|
+
box_thresh: minimal objectness score to consider a box
|
|
94
96
|
head_chans: number of channels in the head layers
|
|
95
97
|
assume_straight_pages: if True, fit straight bounding boxes only
|
|
96
98
|
exportable: onnx exportable returns only logits
|
|
@@ -102,6 +104,7 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
102
104
|
self,
|
|
103
105
|
feat_extractor: IntermediateLayerGetter,
|
|
104
106
|
bin_thresh: float = 0.1,
|
|
107
|
+
box_thresh: float = 0.1,
|
|
105
108
|
head_chans: int = 32,
|
|
106
109
|
assume_straight_pages: bool = True,
|
|
107
110
|
exportable: bool = False,
|
|
@@ -142,7 +145,7 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
142
145
|
)
|
|
143
146
|
|
|
144
147
|
self.postprocessor = LinkNetPostProcessor(
|
|
145
|
-
assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh
|
|
148
|
+
assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
|
|
146
149
|
)
|
|
147
150
|
|
|
148
151
|
for n, m in self.named_modules():
|
|
@@ -175,7 +178,7 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
175
178
|
return out
|
|
176
179
|
|
|
177
180
|
if return_model_output or target is None or return_preds:
|
|
178
|
-
prob_map = torch.sigmoid(logits)
|
|
181
|
+
prob_map = _bf16_to_float32(torch.sigmoid(logits))
|
|
179
182
|
if return_model_output:
|
|
180
183
|
out["out_map"] = prob_map
|
|
181
184
|
|
|
@@ -204,6 +207,7 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
204
207
|
<https://github.com/tensorflow/addons/>`_.
|
|
205
208
|
|
|
206
209
|
Args:
|
|
210
|
+
----
|
|
207
211
|
out_map: output feature map of the model of shape (N, num_classes, H, W)
|
|
208
212
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
209
213
|
gamma: modulating factor in the focal loss formula
|
|
@@ -211,6 +215,7 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
211
215
|
eps: epsilon factor in dice loss
|
|
212
216
|
|
|
213
217
|
Returns:
|
|
218
|
+
-------
|
|
214
219
|
A loss tensor
|
|
215
220
|
"""
|
|
216
221
|
_target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
|
|
@@ -232,10 +237,12 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
232
237
|
# Class reduced
|
|
233
238
|
focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))
|
|
234
239
|
|
|
235
|
-
#
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
240
|
+
# Compute dice loss for each class
|
|
241
|
+
dice_map = torch.softmax(out_map, dim=1) if len(self.class_names) > 1 else proba_map
|
|
242
|
+
# Class reduced
|
|
243
|
+
inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
|
|
244
|
+
cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
|
|
245
|
+
dice_loss = (1 - 2 * inter / (cardinality + eps)).mean()
|
|
239
246
|
|
|
240
247
|
# Return the full loss (equal sum of focal loss and dice loss)
|
|
241
248
|
return focal_loss + dice_loss
|
|
@@ -288,12 +295,14 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
288
295
|
>>> out = model(input_tensor)
|
|
289
296
|
|
|
290
297
|
Args:
|
|
298
|
+
----
|
|
291
299
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
300
|
+
**kwargs: keyword arguments of the LinkNet architecture
|
|
292
301
|
|
|
293
302
|
Returns:
|
|
303
|
+
-------
|
|
294
304
|
text detection architecture
|
|
295
305
|
"""
|
|
296
|
-
|
|
297
306
|
return _linknet(
|
|
298
307
|
"linknet_resnet18",
|
|
299
308
|
pretrained,
|
|
@@ -318,12 +327,14 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
318
327
|
>>> out = model(input_tensor)
|
|
319
328
|
|
|
320
329
|
Args:
|
|
330
|
+
----
|
|
321
331
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
332
|
+
**kwargs: keyword arguments of the LinkNet architecture
|
|
322
333
|
|
|
323
334
|
Returns:
|
|
335
|
+
-------
|
|
324
336
|
text detection architecture
|
|
325
337
|
"""
|
|
326
|
-
|
|
327
338
|
return _linknet(
|
|
328
339
|
"linknet_resnet34",
|
|
329
340
|
pretrained,
|
|
@@ -348,12 +359,14 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
348
359
|
>>> out = model(input_tensor)
|
|
349
360
|
|
|
350
361
|
Args:
|
|
362
|
+
----
|
|
351
363
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
364
|
+
**kwargs: keyword arguments of the LinkNet architecture
|
|
352
365
|
|
|
353
366
|
Returns:
|
|
367
|
+
-------
|
|
354
368
|
text detection architecture
|
|
355
369
|
"""
|
|
356
|
-
|
|
357
370
|
return _linknet(
|
|
358
371
|
"linknet_resnet50",
|
|
359
372
|
pretrained,
|