python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
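
Before the line-level changes below, the file list already tells the release story: a new FAST text-detection family (`doctr/models/detection/fast/*`) with TextNet classification backbones, a WildReceipt dataset loader, new orientation utilities, and a substantial rework of the DBNet loss. As a quick orientation, a hedged smoke-test sketch; the constructor name (`fast_base`) is inferred from the file layout above and is worth double-checking against the 0.8.1 release notes:

```python
import torch
from doctr.models import fast_base  # new in 0.8.x: doctr/models/detection/fast/

model = fast_base(pretrained=False)
input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
out = model(input_tensor, return_preds=True)  # same predictor-style API as DBNet
```

The detailed hunks on this page cover the two differentiable-binarization files, which carry most of the behavioral changes.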
`doctr/models/detection/differentiable_binarization/base.py` (+37 -25):

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -22,6 +22,7 @@ class DBPostProcessor(DetectionPostProcessor):
     <https://github.com/xuannianz/DifferentiableBinarization>`_.
 
     Args:
+    ----
         unclip ratio: ratio used to unshrink polygons
         min_size_box: minimal length (pix) to keep a box
         max_candidates: maximum boxes to consider in a single page
@@ -37,7 +38,7 @@ class DBPostProcessor(DetectionPostProcessor):
         assume_straight_pages: bool = True,
     ) -> None:
         super().__init__(box_thresh, bin_thresh, assume_straight_pages)
-        self.unclip_ratio = 1.5 if assume_straight_pages else 2.2
+        self.unclip_ratio = 1.5
 
     def polygon_to_box(
         self,
@@ -46,9 +47,11 @@ class DBPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
+        ----
             points: The first parameter.
 
         Returns:
+        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
```
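
For context, `polygon_to_box` grows each shrunk text region back to full size with pyclipper, using the DB paper's offset D = A * unclip_ratio / L (polygon area over perimeter). A minimal standalone sketch of that step, assuming `points` is an (N, 2) array of absolute pixel coordinates:

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

def unclip(points: np.ndarray, unclip_ratio: float = 1.5) -> np.ndarray:
    """Expand a shrunk text polygon, following DB: distance = area * ratio / perimeter."""
    poly = Polygon(points)
    distance = poly.area * unclip_ratio / poly.length
    offset = pyclipper.PyclipperOffset()
    offset.AddPath([tuple(map(int, pt)) for pt in points], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    expanded = offset.Execute(distance)  # may be empty for degenerate inputs
    return np.array(expanded[0]).reshape(-1, 2)
```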
Continuing in `base.py`, box extraction:

```diff
@@ -80,7 +83,7 @@ class DBPostProcessor(DetectionPostProcessor):
         if len(expanded_points) < 1:
             return None  # type: ignore[return-value]
         return (
-            cv2.boundingRect(expanded_points)
+            cv2.boundingRect(expanded_points)  # type: ignore[return-value]
             if self.assume_straight_pages
             else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
         )
@@ -90,20 +93,22 @@ class DBPostProcessor(DetectionPostProcessor):
         pred: np.ndarray,
         bitmap: np.ndarray,
     ) -> np.ndarray:
-        """Compute boxes from a bitmap/pred_map
+        """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
+        ----
             pred: Pred map from differentiable binarization output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
+        -------
             np tensor boxes for the bitmap, each box is a 5-element list
                 containing x, y, w, h, score for the box
         """
         height, width = bitmap.shape[:2]
-        min_size_box = 1 + int(height / 512)
+        min_size_box = 2
         boxes: List[Union[np.ndarray, List[float]]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
```
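
The box-size floor is now a flat 2 pixels instead of scaling with the page height. The surrounding flow, compressed into a toy function (the score here is a crude mean over the bounding rectangle, whereas doctr scores over the contour mask):

```python
import cv2
import numpy as np

def boxes_from_pred(pred: np.ndarray, bin_thresh: float = 0.3, min_size_box: int = 2) -> list:
    """Threshold the probability map, extract connected components, filter tiny boxes."""
    bitmap = (pred >= bin_thresh).astype(np.uint8)
    contours, _ = cv2.findContours(bitmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if min(w, h) < min_size_box:  # 0.8.x: fixed floor, no longer 1 + height // 512
            continue
        score = float(pred[y : y + h, x : x + w].mean())
        boxes.append((x, y, w, h, score))
    return boxes
```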
In `_DBNet`, the point-to-segment distance used for the threshold map gains two numerical guards:

```diff
@@ -158,6 +163,7 @@ class _DBNet:
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
+    ----
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
     """
@@ -174,17 +180,20 @@ class _DBNet:
         ys: np.ndarray,
         a: np.ndarray,
         b: np.ndarray,
-        eps: float = 1e-
+        eps: float = 1e-6,
     ) -> float:
         """Compute the distance for each point of the map (xs, ys) to the (a, b) segment
 
         Args:
+        ----
             xs : map of x coordinates (height, width)
             ys : map of y coordinates (height, width)
             a: first point defining the [ab] segment
             b: second point defining the [ab] segment
+            eps: epsilon to avoid division by zero
 
         Returns:
+        -------
             The computed distance
 
         """
@@ -192,9 +201,10 @@ class _DBNet:
         square_dist_2 = np.square(xs - b[0]) + np.square(ys - b[1])
         square_dist = np.square(a[0] - b[0]) + np.square(a[1] - b[1])
         cosin = (square_dist - square_dist_1 - square_dist_2) / (2 * np.sqrt(square_dist_1 * square_dist_2) + eps)
+        cosin = np.clip(cosin, -1.0, 1.0)
         square_sin = 1 - np.square(cosin)
         square_sin = np.nan_to_num(square_sin)
-        result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist)
+        result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist + eps)
         result[cosin < 0] = np.sqrt(np.fmin(square_dist_1, square_dist_2))[cosin < 0]
         return result
 
```
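
Both guards matter: with float rounding, the law-of-cosines ratio can land just outside [-1, 1] for points nearly collinear with the segment, so `1 - cosin**2` goes negative and the square root yields NaN; clipping plus the `+ eps` under the root keeps the threshold map finite. A two-line demonstration:

```python
import numpy as np

cosin = np.array([1.0000000000000002])  # rounding pushed it past 1.0
print(np.sqrt(1 - np.square(cosin)))    # [nan] (with a RuntimeWarning)
print(np.sqrt(1 - np.square(np.clip(cosin, -1.0, 1.0))))  # [0.]
```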
Threshold-map drawing and target building:

```diff
@@ -207,6 +217,7 @@ class _DBNet:
         """Draw a polygon treshold map on a canvas, as described in the DB paper
 
         Args:
+        ----
             polygon : array of coord., to draw the boundary of the polygon
             canvas : threshold map to fill with polygons
             mask : mask for training on threshold polygons
@@ -223,7 +234,7 @@ class _DBNet:
         padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0])
 
         # Fill the mask with 1 on the new padded polygon
-        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
+        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)  # type: ignore[call-overload]
 
         # Get min/max to recover polygon after distance computation
         xmin = padded_polygon[:, 0].min()
@@ -255,7 +266,10 @@ class _DBNet:
 
         # Fill the canvas with the distances computed inside the valid padded polygon
         canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
-            1 - distance_map[ymin_valid - ymin : ymax_valid - ymax + height, xmin_valid - xmin : xmax_valid - xmax + width],
+            1
+            - distance_map[
+                ymin_valid - ymin : ymax_valid - ymax + height, xmin_valid - xmin : xmax_valid - xmax + width
+            ],
             canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
         )
 
@@ -264,7 +278,7 @@ class _DBNet:
     def build_target(
         self,
         target: List[Dict[str, np.ndarray]],
-        output_shape: Tuple[int, int, int, int],
+        output_shape: Tuple[int, int, int],
         channels_last: bool = True,
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
@@ -274,23 +288,24 @@ class _DBNet:
 
         input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32
 
+        h: int
+        w: int
         if channels_last:
-            h, w = output_shape
-            target_shape = (output_shape[0], output_shape[-1], h, w)  # (Batch_size, num_classes, h, w)
+            h, w, num_classes = output_shape
         else:
-            h, w = output_shape
-
+            num_classes, h, w = output_shape
+        target_shape = (len(target), num_classes, h, w)
+
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
         seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
         thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32)
-        thresh_mask: np.ndarray = np.
+        thresh_mask: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
 
         for idx, tgt in enumerate(target):
             for class_idx, _tgt in enumerate(tgt.values()):
                 # Draw each polygon on gt
                 if _tgt.shape[0] == 0:
                     # Empty image, full masked
-                    # seg_mask[idx, :, :, class_idx] = False
                     seg_mask[idx, class_idx] = False
 
                 # Absolute bounding boxes
```
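
`build_target` previously mixed batched and per-map shapes; it now takes the shape of a single output map and prepends the batch size itself. A minimal sketch of the unpacking (names hypothetical):

```python
from typing import Tuple

def target_shape(output_shape: Tuple[int, int, int], n_samples: int, channels_last: bool) -> Tuple[int, int, int, int]:
    """Derive the (N, C, H, W) target shape from a single output map's shape."""
    if channels_last:  # TensorFlow layout (H, W, C)
        h, w, num_classes = output_shape
    else:  # PyTorch layout (C, H, W)
        num_classes, h, w = output_shape
    return (n_samples, num_classes, h, w)

assert target_shape((256, 256, 1), 8, True) == (8, 1, 256, 256)
assert target_shape((1, 256, 256), 8, False) == (8, 1, 256, 256)
```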
Polygon shrinking for the segmentation target now validates its output before rasterizing:

```diff
@@ -316,10 +331,9 @@ class _DBNet:
                 )
                 boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
 
-                for box, box_size
+                for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
                     # Mask boxes that are too small
                     if box_size < self.min_size_box:
-                        # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
                         seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                         continue
 
@@ -329,19 +343,17 @@ class _DBNet:
                     subject = [tuple(coor) for coor in poly]
                     padding = pyclipper.PyclipperOffset()
                     padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
-
+                    shrunken = padding.Execute(-distance)
 
                     # Draw polygon on gt if it is valid
-                    if len(
-                    # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
+                    if len(shrunken) == 0:
                         seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                         continue
-
-                    if
-                    # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
+                    shrunken = np.array(shrunken[0]).reshape(-1, 2)
+                    if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                         seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                         continue
-                    cv2.fillPoly(seg_target[idx, class_idx], [
+                    cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)  # type: ignore[call-overload]
 
                     # Draw on both thresh map and thresh mask
                     poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
```
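
Target building is the mirror image of post-processing: ground-truth polygons are shrunk with a negative pyclipper offset before rasterization, and 0.8.x adds explicit guards for degenerate results. A sketch of the shrink step, using the DB paper's offset D = A(1 - r^2)/L with r = 0.4 (the ratio is an assumption; the diff does not show where `distance` is computed):

```python
from typing import Optional

import numpy as np
import pyclipper
from shapely.geometry import Polygon

def shrink_polygon(poly: np.ndarray, shrink_ratio: float = 0.4) -> Optional[np.ndarray]:
    """Shrink a text polygon for the segmentation target; None if it degenerates."""
    shape = Polygon(poly)
    distance = shape.area * (1 - shrink_ratio**2) / shape.length
    padding = pyclipper.PyclipperOffset()
    padding.AddPath([tuple(map(int, pt)) for pt in poly], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    shrunken = padding.Execute(-distance)  # negative offset shrinks
    if len(shrunken) == 0:  # polygon collapsed entirely
        return None
    shrunken = np.array(shrunken[0]).reshape(-1, 2)
    # the new guard: need at least 3 points and a non-self-intersecting shape
    if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
        return None
    return shrunken
```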
`doctr/models/detection/differentiable_binarization/pytorch.py` (+80 -104):

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -16,10 +16,10 @@ from torchvision.ops.deform_conv import DeformConv2d
 from doctr.file_utils import CLASS_NAME
 
 from ...classification import mobilenet_v3_large
-from ...utils import load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import DBPostProcessor, _DBNet
 
-__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "db_resnet50_rotation"]
+__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
 
 
 default_cfgs: Dict[str, Dict[str, Any]] = {
@@ -27,25 +27,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-79bd7d70.pt&src=0",
     },
     "db_resnet34": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url":
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet34-cb6aed9e.pt&src=0",
     },
     "db_mobilenet_v3_large": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
-    },
-    "db_resnet50_rotation": {
-        "input_shape": (3, 1024, 1024),
-        "mean": (0.798, 0.785, 0.772),
-        "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/db_resnet50-1138863a.pt&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-81e9b152.pt&src=0",
     },
 }
 
@@ -63,28 +57,24 @@ class FeaturePyramidNetwork(nn.Module):
 
         conv_layer = DeformConv2d if deform_conv else nn.Conv2d
 
-        self.in_branches = nn.ModuleList(
-            [
-                nn.Sequential(
-                    conv_layer(chans, out_channels, 1, bias=False),
-                    nn.BatchNorm2d(out_channels),
-                    nn.ReLU(inplace=True),
-                )
-                for idx, chans in enumerate(in_channels)
-            ]
-        )
+        self.in_branches = nn.ModuleList([
+            nn.Sequential(
+                conv_layer(chans, out_channels, 1, bias=False),
+                nn.BatchNorm2d(out_channels),
+                nn.ReLU(inplace=True),
+            )
+            for idx, chans in enumerate(in_channels)
+        ])
         self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
-        self.out_branches = nn.ModuleList(
-            [
-                nn.Sequential(
-                    conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
-                    nn.BatchNorm2d(out_chans),
-                    nn.ReLU(inplace=True),
-                    nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
-                )
-                for idx, chans in enumerate(in_channels)
-            ]
-        )
+        self.out_branches = nn.ModuleList([
+            nn.Sequential(
+                conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
+                nn.BatchNorm2d(out_chans),
+                nn.ReLU(inplace=True),
+                nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
+            )
+            for idx, chans in enumerate(in_channels)
+        ])
 
     def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
         if len(x) != len(self.out_branches):
@@ -106,9 +96,12 @@ class DBNet(_DBNet, nn.Module):
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
+    ----
         feature extractor: the backbone serving as feature extractor
         head_chans: the number of channels in the head
         deform_conv: whether to use deformable convolution
+        bin_thresh: threshold for binarization
+        box_thresh: minimal objectness score to consider a box
         assume_straight_pages: if True, fit straight bounding boxes only
         exportable: onnx exportable returns only logits
         cfg: the configuration dict of the model
@@ -121,6 +114,7 @@ class DBNet(_DBNet, nn.Module):
         head_chans: int = 256,
         deform_conv: bool = False,
         bin_thresh: float = 0.3,
+        box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
         cfg: Optional[Dict[str, Any]] = None,
```
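
`box_thresh` joins `bin_thresh` as a constructor argument, and both are forwarded to the post-processor (next hunk), so the detection thresholds can be tuned at build time. A usage sketch:

```python
import torch
from doctr.models import db_resnet50

# both thresholds flow through to the DBPostProcessor
model = db_resnet50(pretrained=False, bin_thresh=0.3, box_thresh=0.1)
out = model(torch.rand((1, 3, 1024, 1024)), return_preds=True)
```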
The post-processor wiring, plus a cast of probability maps out of bfloat16 before they leave torch:

```diff
@@ -169,7 +163,9 @@ class DBNet(_DBNet, nn.Module):
             nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2),
         )
 
-        self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
+        self.postprocessor = DBPostProcessor(
+            assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+        )
 
         for n, m in self.named_modules():
             # Don't override the initialization of the backbone
@@ -203,7 +199,7 @@ class DBNet(_DBNet, nn.Module):
             return out
 
         if return_model_output or target is None or return_preds:
-            prob_map = torch.sigmoid(logits)
+            prob_map = _bf16_to_float32(torch.sigmoid(logits))
 
         if return_model_output:
             out["out_map"] = prob_map
```
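
The `_bf16_to_float32` wrapper lands wherever a map is about to leave torch for numpy-based post-processing; its implementation lives in `doctr/models/utils` and is not shown in this diff. A plausible minimal version, since numpy has no bfloat16 dtype:

```python
import torch

def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
    # numpy cannot represent bfloat16, so upcast before any .numpy() call
    return x.float() if x.dtype == torch.bfloat16 else x
```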
The loss itself is rewritten around a focal term (bare `-` and `...` markers below stand in for removed lines whose content was not preserved in this view):

```diff
@@ -222,64 +218,72 @@ class DBNet(_DBNet, nn.Module):
 
         return out
 
-    def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
+    def compute_loss(
+        self,
+        out_map: torch.Tensor,
+        thresh_map: torch.Tensor,
+        target: List[np.ndarray],
+        gamma: float = 2.0,
+        alpha: float = 0.5,
+        eps: float = 1e-8,
+    ) -> torch.Tensor:
         """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
+        ----
             out_map: output feature map of the model of shape (N, C, H, W)
             thresh_map: threshold map of shape (N, C, H, W)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
+            gamma: modulating factor in the focal loss formula
+            alpha: balancing factor in the focal loss formula
+            eps: epsilon factor in dice loss
 
         Returns:
+        -------
             A loss tensor
         """
+        if gamma < 0:
+            raise ValueError("Value of gamma should be greater than or equal to zero.")
 
         prob_map = torch.sigmoid(out_map)
         thresh_map = torch.sigmoid(thresh_map)
 
-        targets = self.build_target(target,
+        targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
         thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3])
         thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device)
 
-        # Compute balanced BCE loss for proba_map
-        bce_scale = 5.0
-        balanced_bce_loss = torch.zeros(1, device=out_map.device)
-        dice_loss = torch.zeros(1, device=out_map.device)
-        l1_loss = torch.zeros(1, device=out_map.device)
         if torch.any(seg_mask):
-            ...
-            )
-            ...
-            dice_loss = 1 - 2.0 * inter / union
+            # Focal loss
+            focal_scale = 10.0
+            bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
+
+            p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
+            alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
+            # Unreduced version
+            focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
+            # Class reduced
+            focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))
+
+            # Compute dice loss for each class or for approx binary_map
+            if len(self.class_names) > 1:
+                dice_map = torch.softmax(out_map, dim=1)
+            else:
+                # compute binary map instead
+                dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
+            # Class reduced
+            inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
+            cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
+            dice_loss = (1 - 2 * inter / (cardinality + eps)).mean()
 
         # Compute l1 loss for thresh_map
-        l1_scale = 10.0
         if torch.any(thresh_mask):
-            l1_loss =
+            l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps)
 
-        return
+        return l1_loss + focal_scale * focal_loss + dice_loss
```
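
The headline change: the old balanced-BCE term is replaced by a focal loss, FL(p_t) = alpha_t * (1 - p_t)^gamma * BCE, and the dice term now works per class with an epsilon-stabilized cardinality. A standalone sketch of the focal term on toy tensors (no doctr imports, same gamma and alpha defaults as above):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float = 0.5) -> torch.Tensor:
    """FL(p_t) = alpha_t * (1 - p_t)**gamma * BCE, mean-reduced over all pixels."""
    prob = torch.sigmoid(logits)
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = prob * target + (1 - prob) * (1 - target)        # probability of the true class
    alpha_t = alpha * target + (1 - alpha) * (1 - target)  # class-balancing weight
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

logits = torch.randn(2, 1, 8, 8)
target = (torch.rand(2, 1, 8, 8) > 0.5).float()
print(focal_loss(logits, target))  # easy pixels are down-weighted by (1 - p_t)**gamma
```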
Finally, the model constructors get docstring touch-ups, and the rotation-specific variant is dropped:

```diff
@@ -337,12 +341,14 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
+    -------
         text detection architecture
     """
-
     return _dbnet(
         "db_resnet34",
         pretrained,
@@ -370,12 +376,14 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
+    -------
         text detection architecture
     """
-
     return _dbnet(
         "db_resnet50",
         pretrained,
@@ -403,12 +411,14 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
+    -------
         text detection architecture
     """
-
     return _dbnet(
         "db_mobilenet_v3_large",
         pretrained,
@@ -423,37 +433,3 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
         ],
         **kwargs,
     )
-
-
-def db_resnet50_rotation(pretrained: bool = False, **kwargs: Any) -> DBNet:
-    """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
-    <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
-    This model is trained with rotated documents
-
-    >>> import torch
-    >>> from doctr.models import db_resnet50_rotation
-    >>> model = db_resnet50_rotation(pretrained=True)
-    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
-
-    Returns:
-        text detection architecture
-    """
-
-    return _dbnet(
-        "db_resnet50_rotation",
-        pretrained,
-        resnet50,
-        ["layer1", "layer2", "layer3", "layer4"],
-        None,
-        ignore_keys=[
-            "prob_head.6.weight",
-            "prob_head.6.bias",
-            "thresh_head.6.weight",
-            "thresh_head.6.bias",
-        ],
-        **kwargs,
-    )
```
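
With `db_resnet50_rotation` removed, rotated documents are handled in 0.8.x by the regular checkpoints combined with rotation-aware post-processing rather than a dedicated model. A hedged migration sketch using the public zoo API:

```python
from doctr.models import detection_predictor

# 0.7.0: model = db_resnet50_rotation(pretrained=True)
# 0.8.x: use the standard checkpoint and ask the predictor for rotated boxes
predictor = detection_predictor(
    "db_resnet50",
    pretrained=True,
    assume_straight_pages=False,  # fit rotated (4, 2) polygons instead of straight boxes
)
```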