python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +8 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +7 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +4 -5
- doctr/datasets/ic13.py +4 -5
- doctr/datasets/iiit5k.py +6 -5
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +6 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +6 -5
- doctr/datasets/svt.py +4 -5
- doctr/datasets/synthtext.py +4 -5
- doctr/datasets/utils.py +34 -29
- doctr/datasets/vocabs.py +17 -7
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +2 -7
- doctr/io/elements.py +59 -79
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +30 -48
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +8 -11
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +5 -17
- doctr/models/classification/mobilenet/tensorflow.py +8 -21
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +6 -8
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +20 -31
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +8 -15
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +9 -12
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +6 -12
- doctr/models/classification/zoo.py +19 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +14 -26
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +14 -23
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +3 -7
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +18 -19
- doctr/models/kie_predictor/tensorflow.py +13 -14
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +12 -13
- doctr/models/predictor/tensorflow.py +8 -9
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +11 -23
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +12 -22
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +16 -22
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +12 -21
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +12 -20
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +14 -17
- doctr/models/utils/tensorflow.py +17 -16
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +16 -47
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
7
|
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import tensorflow as tf
|
|
@@ -29,7 +29,7 @@ from .base import DBPostProcessor, _DBNet
|
|
|
29
29
|
__all__ = ["DBNet", "db_resnet50", "db_mobilenet_v3_large"]
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
default_cfgs:
|
|
32
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
33
33
|
"db_resnet50": {
|
|
34
34
|
"mean": (0.798, 0.785, 0.772),
|
|
35
35
|
"std": (0.264, 0.2749, 0.287),
|
|
@@ -50,7 +50,6 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
|
|
|
50
50
|
<https://arxiv.org/pdf/1612.03144.pdf>`_.
|
|
51
51
|
|
|
52
52
|
Args:
|
|
53
|
-
----
|
|
54
53
|
channels: number of channel to output
|
|
55
54
|
"""
|
|
56
55
|
|
|
@@ -72,12 +71,10 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
|
|
|
72
71
|
"""Module which performs a 3x3 convolution followed by up-sampling
|
|
73
72
|
|
|
74
73
|
Args:
|
|
75
|
-
----
|
|
76
74
|
channels: number of output channels
|
|
77
75
|
dilation_factor (int): dilation factor to scale the convolution output before concatenation
|
|
78
76
|
|
|
79
77
|
Returns:
|
|
80
|
-
-------
|
|
81
78
|
a keras.layers.Layer object, wrapping these operations in a sequential module
|
|
82
79
|
|
|
83
80
|
"""
|
|
@@ -95,7 +92,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
|
|
|
95
92
|
|
|
96
93
|
def call(
|
|
97
94
|
self,
|
|
98
|
-
x:
|
|
95
|
+
x: list[tf.Tensor],
|
|
99
96
|
**kwargs: Any,
|
|
100
97
|
) -> tf.Tensor:
|
|
101
98
|
# Channel mapping
|
|
@@ -114,7 +111,6 @@ class DBNet(_DBNet, Model, NestedObject):
|
|
|
114
111
|
<https://arxiv.org/pdf/1911.08947.pdf>`_.
|
|
115
112
|
|
|
116
113
|
Args:
|
|
117
|
-
----
|
|
118
114
|
feature extractor: the backbone serving as feature extractor
|
|
119
115
|
fpn_channels: number of channels each extracted feature maps is mapped to
|
|
120
116
|
bin_thresh: threshold for binarization
|
|
@@ -125,7 +121,7 @@ class DBNet(_DBNet, Model, NestedObject):
|
|
|
125
121
|
class_names: list of class names
|
|
126
122
|
"""
|
|
127
123
|
|
|
128
|
-
_children_names:
|
|
124
|
+
_children_names: list[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
|
|
129
125
|
|
|
130
126
|
def __init__(
|
|
131
127
|
self,
|
|
@@ -135,8 +131,8 @@ class DBNet(_DBNet, Model, NestedObject):
|
|
|
135
131
|
box_thresh: float = 0.1,
|
|
136
132
|
assume_straight_pages: bool = True,
|
|
137
133
|
exportable: bool = False,
|
|
138
|
-
cfg:
|
|
139
|
-
class_names:
|
|
134
|
+
cfg: dict[str, Any] | None = None,
|
|
135
|
+
class_names: list[str] = [CLASS_NAME],
|
|
140
136
|
) -> None:
|
|
141
137
|
super().__init__()
|
|
142
138
|
self.class_names = class_names
|
|
@@ -175,7 +171,7 @@ class DBNet(_DBNet, Model, NestedObject):
|
|
|
175
171
|
self,
|
|
176
172
|
out_map: tf.Tensor,
|
|
177
173
|
thresh_map: tf.Tensor,
|
|
178
|
-
target:
|
|
174
|
+
target: list[dict[str, np.ndarray]],
|
|
179
175
|
gamma: float = 2.0,
|
|
180
176
|
alpha: float = 0.5,
|
|
181
177
|
eps: float = 1e-8,
|
|
@@ -184,7 +180,6 @@ class DBNet(_DBNet, Model, NestedObject):
|
|
|
184
180
|
and a list of masks for each image. From there it computes the loss with the model output
|
|
185
181
|
|
|
186
182
|
Args:
|
|
187
|
-
----
|
|
188
183
|
out_map: output feature map of the model of shape (N, H, W, C)
|
|
189
184
|
thresh_map: threshold map of shape (N, H, W, C)
|
|
190
185
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
@@ -193,7 +188,6 @@ class DBNet(_DBNet, Model, NestedObject):
|
|
|
193
188
|
eps: epsilon factor in dice loss
|
|
194
189
|
|
|
195
190
|
Returns:
|
|
196
|
-
-------
|
|
197
191
|
A loss tensor
|
|
198
192
|
"""
|
|
199
193
|
if gamma < 0:
|
|
@@ -246,16 +240,16 @@ class DBNet(_DBNet, Model, NestedObject):
|
|
|
246
240
|
def call(
|
|
247
241
|
self,
|
|
248
242
|
x: tf.Tensor,
|
|
249
|
-
target:
|
|
243
|
+
target: list[dict[str, np.ndarray]] | None = None,
|
|
250
244
|
return_model_output: bool = False,
|
|
251
245
|
return_preds: bool = False,
|
|
252
246
|
**kwargs: Any,
|
|
253
|
-
) ->
|
|
247
|
+
) -> dict[str, Any]:
|
|
254
248
|
feat_maps = self.feat_extractor(x, **kwargs)
|
|
255
249
|
feat_concat = self.fpn(feat_maps, **kwargs)
|
|
256
250
|
logits = self.probability_head(feat_concat, **kwargs)
|
|
257
251
|
|
|
258
|
-
out:
|
|
252
|
+
out: dict[str, tf.Tensor] = {}
|
|
259
253
|
if self.exportable:
|
|
260
254
|
out["logits"] = logits
|
|
261
255
|
return out
|
|
@@ -282,9 +276,9 @@ def _db_resnet(
|
|
|
282
276
|
arch: str,
|
|
283
277
|
pretrained: bool,
|
|
284
278
|
backbone_fn,
|
|
285
|
-
fpn_layers:
|
|
279
|
+
fpn_layers: list[str],
|
|
286
280
|
pretrained_backbone: bool = True,
|
|
287
|
-
input_shape:
|
|
281
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
288
282
|
**kwargs: Any,
|
|
289
283
|
) -> DBNet:
|
|
290
284
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -328,9 +322,9 @@ def _db_mobilenet(
|
|
|
328
322
|
arch: str,
|
|
329
323
|
pretrained: bool,
|
|
330
324
|
backbone_fn,
|
|
331
|
-
fpn_layers:
|
|
325
|
+
fpn_layers: list[str],
|
|
332
326
|
pretrained_backbone: bool = True,
|
|
333
|
-
input_shape:
|
|
327
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
334
328
|
**kwargs: Any,
|
|
335
329
|
) -> DBNet:
|
|
336
330
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -379,12 +373,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
|
|
|
379
373
|
>>> out = model(input_tensor)
|
|
380
374
|
|
|
381
375
|
Args:
|
|
382
|
-
----
|
|
383
376
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
384
377
|
**kwargs: keyword arguments of the DBNet architecture
|
|
385
378
|
|
|
386
379
|
Returns:
|
|
387
|
-
-------
|
|
388
380
|
text detection architecture
|
|
389
381
|
"""
|
|
390
382
|
return _db_resnet(
|
|
@@ -407,12 +399,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
|
|
|
407
399
|
>>> out = model(input_tensor)
|
|
408
400
|
|
|
409
401
|
Args:
|
|
410
|
-
----
|
|
411
402
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
412
403
|
**kwargs: keyword arguments of the DBNet architecture
|
|
413
404
|
|
|
414
405
|
Returns:
|
|
415
|
-
-------
|
|
416
406
|
text detection architecture
|
|
417
407
|
"""
|
|
418
408
|
return _db_mobilenet(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
elif
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,11 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
7
|
|
|
8
|
-
from typing import Dict, List, Tuple, Union
|
|
9
8
|
|
|
10
9
|
import cv2
|
|
11
10
|
import numpy as np
|
|
@@ -23,7 +22,6 @@ class FASTPostProcessor(DetectionPostProcessor):
|
|
|
23
22
|
"""Implements a post processor for FAST model.
|
|
24
23
|
|
|
25
24
|
Args:
|
|
26
|
-
----
|
|
27
25
|
bin_thresh: threshold used to binzarized p_map at inference time
|
|
28
26
|
box_thresh: minimal objectness score to consider a box
|
|
29
27
|
assume_straight_pages: whether the inputs were expected to have horizontal text elements
|
|
@@ -45,11 +43,9 @@ class FASTPostProcessor(DetectionPostProcessor):
|
|
|
45
43
|
"""Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
|
|
46
44
|
|
|
47
45
|
Args:
|
|
48
|
-
----
|
|
49
46
|
points: The first parameter.
|
|
50
47
|
|
|
51
48
|
Returns:
|
|
52
|
-
-------
|
|
53
49
|
a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
|
|
54
50
|
"""
|
|
55
51
|
if not self.assume_straight_pages:
|
|
@@ -94,24 +90,22 @@ class FASTPostProcessor(DetectionPostProcessor):
|
|
|
94
90
|
"""Compute boxes from a bitmap/pred_map: find connected components then filter boxes
|
|
95
91
|
|
|
96
92
|
Args:
|
|
97
|
-
----
|
|
98
93
|
pred: Pred map from differentiable linknet output
|
|
99
94
|
bitmap: Bitmap map computed from pred (binarized)
|
|
100
95
|
angle_tol: Comparison tolerance of the angle with the median angle across the page
|
|
101
96
|
ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
|
|
102
97
|
|
|
103
98
|
Returns:
|
|
104
|
-
-------
|
|
105
99
|
np tensor boxes for the bitmap, each box is a 6-element list
|
|
106
100
|
containing x, y, w, h, alpha, score for the box
|
|
107
101
|
"""
|
|
108
102
|
height, width = bitmap.shape[:2]
|
|
109
|
-
boxes:
|
|
103
|
+
boxes: list[np.ndarray | list[float]] = []
|
|
110
104
|
# get contours from connected components on the bitmap
|
|
111
105
|
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
112
106
|
for contour in contours:
|
|
113
107
|
# Check whether smallest enclosing bounding box is not too small
|
|
114
|
-
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
|
|
108
|
+
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
|
|
115
109
|
continue
|
|
116
110
|
# Compute objectness
|
|
117
111
|
if self.assume_straight_pages:
|
|
@@ -158,20 +152,18 @@ class _FAST(BaseModel):
|
|
|
158
152
|
|
|
159
153
|
def build_target(
|
|
160
154
|
self,
|
|
161
|
-
target:
|
|
162
|
-
output_shape:
|
|
155
|
+
target: list[dict[str, np.ndarray]],
|
|
156
|
+
output_shape: tuple[int, int, int],
|
|
163
157
|
channels_last: bool = True,
|
|
164
|
-
) ->
|
|
158
|
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
165
159
|
"""Build the target, and it's mask to be used from loss computation.
|
|
166
160
|
|
|
167
161
|
Args:
|
|
168
|
-
----
|
|
169
162
|
target: target coming from dataset
|
|
170
163
|
output_shape: shape of the output of the model without batch_size
|
|
171
164
|
channels_last: whether channels are last or not
|
|
172
165
|
|
|
173
166
|
Returns:
|
|
174
|
-
-------
|
|
175
167
|
the new formatted target, mask and shrunken text kernel
|
|
176
168
|
"""
|
|
177
169
|
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
|
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
import torch
|
|
@@ -21,7 +22,7 @@ from .base import _FAST, FASTPostProcessor
|
|
|
21
22
|
__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
default_cfgs:
|
|
25
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
25
26
|
"fast_tiny": {
|
|
26
27
|
"input_shape": (3, 1024, 1024),
|
|
27
28
|
"mean": (0.798, 0.785, 0.772),
|
|
@@ -47,7 +48,6 @@ class FastNeck(nn.Module):
|
|
|
47
48
|
"""Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layers.
|
|
48
49
|
|
|
49
50
|
Args:
|
|
50
|
-
----
|
|
51
51
|
in_channels: number of input channels
|
|
52
52
|
out_channels: number of output channels
|
|
53
53
|
"""
|
|
@@ -77,7 +77,6 @@ class FastHead(nn.Sequential):
|
|
|
77
77
|
"""Head of the FAST architecture
|
|
78
78
|
|
|
79
79
|
Args:
|
|
80
|
-
----
|
|
81
80
|
in_channels: number of input channels
|
|
82
81
|
num_classes: number of output classes
|
|
83
82
|
out_channels: number of output channels
|
|
@@ -91,7 +90,7 @@ class FastHead(nn.Sequential):
|
|
|
91
90
|
out_channels: int = 128,
|
|
92
91
|
dropout: float = 0.1,
|
|
93
92
|
) -> None:
|
|
94
|
-
_layers:
|
|
93
|
+
_layers: list[nn.Module] = [
|
|
95
94
|
FASTConvLayer(in_channels, out_channels, kernel_size=3),
|
|
96
95
|
nn.Dropout(dropout),
|
|
97
96
|
nn.Conv2d(out_channels, num_classes, kernel_size=1, bias=False),
|
|
@@ -104,7 +103,6 @@ class FAST(_FAST, nn.Module):
|
|
|
104
103
|
<https://arxiv.org/pdf/2111.02394.pdf>`_.
|
|
105
104
|
|
|
106
105
|
Args:
|
|
107
|
-
----
|
|
108
106
|
feat extractor: the backbone serving as feature extractor
|
|
109
107
|
bin_thresh: threshold for binarization
|
|
110
108
|
box_thresh: minimal objectness score to consider a box
|
|
@@ -125,8 +123,8 @@ class FAST(_FAST, nn.Module):
|
|
|
125
123
|
pooling_size: int = 4, # different from paper performs better on close text-rich images
|
|
126
124
|
assume_straight_pages: bool = True,
|
|
127
125
|
exportable: bool = False,
|
|
128
|
-
cfg:
|
|
129
|
-
class_names:
|
|
126
|
+
cfg: dict[str, Any] = {},
|
|
127
|
+
class_names: list[str] = [CLASS_NAME],
|
|
130
128
|
) -> None:
|
|
131
129
|
super().__init__()
|
|
132
130
|
self.class_names = class_names
|
|
@@ -175,10 +173,10 @@ class FAST(_FAST, nn.Module):
|
|
|
175
173
|
def forward(
|
|
176
174
|
self,
|
|
177
175
|
x: torch.Tensor,
|
|
178
|
-
target:
|
|
176
|
+
target: list[np.ndarray] | None = None,
|
|
179
177
|
return_model_output: bool = False,
|
|
180
178
|
return_preds: bool = False,
|
|
181
|
-
) ->
|
|
179
|
+
) -> dict[str, torch.Tensor]:
|
|
182
180
|
# Extract feature maps at different stages
|
|
183
181
|
feats = self.feat_extractor(x)
|
|
184
182
|
feats = [feats[str(idx)] for idx in range(len(feats))]
|
|
@@ -186,7 +184,7 @@ class FAST(_FAST, nn.Module):
|
|
|
186
184
|
feat_concat = self.neck(feats)
|
|
187
185
|
logits = F.interpolate(self.prob_head(feat_concat), size=x.shape[-2:], mode="bilinear")
|
|
188
186
|
|
|
189
|
-
out:
|
|
187
|
+
out: dict[str, Any] = {}
|
|
190
188
|
if self.exportable:
|
|
191
189
|
out["logits"] = logits
|
|
192
190
|
return out
|
|
@@ -198,11 +196,16 @@ class FAST(_FAST, nn.Module):
|
|
|
198
196
|
out["out_map"] = prob_map
|
|
199
197
|
|
|
200
198
|
if target is None or return_preds:
|
|
199
|
+
# Disable for torch.compile compatibility
|
|
200
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
201
|
+
def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
|
|
202
|
+
return [
|
|
203
|
+
dict(zip(self.class_names, preds))
|
|
204
|
+
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
|
|
205
|
+
]
|
|
206
|
+
|
|
201
207
|
# Post-process boxes (keep only text predictions)
|
|
202
|
-
out["preds"] =
|
|
203
|
-
dict(zip(self.class_names, preds))
|
|
204
|
-
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
|
|
205
|
-
]
|
|
208
|
+
out["preds"] = _postprocess(prob_map)
|
|
206
209
|
|
|
207
210
|
if target is not None:
|
|
208
211
|
loss = self.compute_loss(logits, target)
|
|
@@ -213,19 +216,17 @@ class FAST(_FAST, nn.Module):
|
|
|
213
216
|
def compute_loss(
|
|
214
217
|
self,
|
|
215
218
|
out_map: torch.Tensor,
|
|
216
|
-
target:
|
|
219
|
+
target: list[np.ndarray],
|
|
217
220
|
eps: float = 1e-6,
|
|
218
221
|
) -> torch.Tensor:
|
|
219
222
|
"""Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
|
|
220
223
|
|
|
221
224
|
Args:
|
|
222
|
-
----
|
|
223
225
|
out_map: output feature map of the model of shape (N, num_classes, H, W)
|
|
224
226
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
225
227
|
eps: epsilon factor in dice loss
|
|
226
228
|
|
|
227
229
|
Returns:
|
|
228
|
-
-------
|
|
229
230
|
A loss tensor
|
|
230
231
|
"""
|
|
231
232
|
targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
|
|
@@ -279,15 +280,13 @@ class FAST(_FAST, nn.Module):
|
|
|
279
280
|
return text_loss + kernel_loss
|
|
280
281
|
|
|
281
282
|
|
|
282
|
-
def reparameterize(model:
|
|
283
|
+
def reparameterize(model: FAST | nn.Module) -> FAST:
|
|
283
284
|
"""Fuse batchnorm and conv layers and reparameterize the model
|
|
284
285
|
|
|
285
|
-
|
|
286
|
-
----
|
|
286
|
+
Args:
|
|
287
287
|
model: the FAST model to reparameterize
|
|
288
288
|
|
|
289
289
|
Returns:
|
|
290
|
-
-------
|
|
291
290
|
the reparameterized model
|
|
292
291
|
"""
|
|
293
292
|
last_conv = None
|
|
@@ -324,9 +323,9 @@ def _fast(
|
|
|
324
323
|
arch: str,
|
|
325
324
|
pretrained: bool,
|
|
326
325
|
backbone_fn: Callable[[bool], nn.Module],
|
|
327
|
-
feat_layers:
|
|
326
|
+
feat_layers: list[str],
|
|
328
327
|
pretrained_backbone: bool = True,
|
|
329
|
-
ignore_keys:
|
|
328
|
+
ignore_keys: list[str] | None = None,
|
|
330
329
|
**kwargs: Any,
|
|
331
330
|
) -> FAST:
|
|
332
331
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -366,12 +365,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
366
365
|
>>> out = model(input_tensor)
|
|
367
366
|
|
|
368
367
|
Args:
|
|
369
|
-
----
|
|
370
368
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
371
369
|
**kwargs: keyword arguments of the DBNet architecture
|
|
372
370
|
|
|
373
371
|
Returns:
|
|
374
|
-
-------
|
|
375
372
|
text detection architecture
|
|
376
373
|
"""
|
|
377
374
|
return _fast(
|
|
@@ -395,12 +392,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
395
392
|
>>> out = model(input_tensor)
|
|
396
393
|
|
|
397
394
|
Args:
|
|
398
|
-
----
|
|
399
395
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
400
396
|
**kwargs: keyword arguments of the DBNet architecture
|
|
401
397
|
|
|
402
398
|
Returns:
|
|
403
|
-
-------
|
|
404
399
|
text detection architecture
|
|
405
400
|
"""
|
|
406
401
|
return _fast(
|
|
@@ -424,12 +419,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
424
419
|
>>> out = model(input_tensor)
|
|
425
420
|
|
|
426
421
|
Args:
|
|
427
|
-
----
|
|
428
422
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
429
423
|
**kwargs: keyword arguments of the DBNet architecture
|
|
430
424
|
|
|
431
425
|
Returns:
|
|
432
|
-
-------
|
|
433
426
|
text detection architecture
|
|
434
427
|
"""
|
|
435
428
|
return _fast(
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
7
|
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import tensorflow as tf
|
|
@@ -23,7 +23,7 @@ from .base import _FAST, FASTPostProcessor
|
|
|
23
23
|
__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
default_cfgs:
|
|
26
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
27
27
|
"fast_tiny": {
|
|
28
28
|
"input_shape": (1024, 1024, 3),
|
|
29
29
|
"mean": (0.798, 0.785, 0.772),
|
|
@@ -49,7 +49,6 @@ class FastNeck(layers.Layer, NestedObject):
|
|
|
49
49
|
"""Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer.
|
|
50
50
|
|
|
51
51
|
Args:
|
|
52
|
-
----
|
|
53
52
|
in_channels: number of input channels
|
|
54
53
|
out_channels: number of output channels
|
|
55
54
|
"""
|
|
@@ -77,7 +76,6 @@ class FastHead(Sequential):
|
|
|
77
76
|
"""Head of the FAST architecture
|
|
78
77
|
|
|
79
78
|
Args:
|
|
80
|
-
----
|
|
81
79
|
in_channels: number of input channels
|
|
82
80
|
num_classes: number of output classes
|
|
83
81
|
out_channels: number of output channels
|
|
@@ -104,7 +102,6 @@ class FAST(_FAST, Model, NestedObject):
|
|
|
104
102
|
<https://arxiv.org/pdf/2111.02394.pdf>`_.
|
|
105
103
|
|
|
106
104
|
Args:
|
|
107
|
-
----
|
|
108
105
|
feature extractor: the backbone serving as feature extractor
|
|
109
106
|
bin_thresh: threshold for binarization
|
|
110
107
|
box_thresh: minimal objectness score to consider a box
|
|
@@ -116,7 +113,7 @@ class FAST(_FAST, Model, NestedObject):
|
|
|
116
113
|
class_names: list of class names
|
|
117
114
|
"""
|
|
118
115
|
|
|
119
|
-
_children_names:
|
|
116
|
+
_children_names: list[str] = ["feat_extractor", "neck", "head", "postprocessor"]
|
|
120
117
|
|
|
121
118
|
def __init__(
|
|
122
119
|
self,
|
|
@@ -127,8 +124,8 @@ class FAST(_FAST, Model, NestedObject):
|
|
|
127
124
|
pooling_size: int = 4, # different from paper performs better on close text-rich images
|
|
128
125
|
assume_straight_pages: bool = True,
|
|
129
126
|
exportable: bool = False,
|
|
130
|
-
cfg:
|
|
131
|
-
class_names:
|
|
127
|
+
cfg: dict[str, Any] = {},
|
|
128
|
+
class_names: list[str] = [CLASS_NAME],
|
|
132
129
|
) -> None:
|
|
133
130
|
super().__init__()
|
|
134
131
|
self.class_names = class_names
|
|
@@ -159,19 +156,17 @@ class FAST(_FAST, Model, NestedObject):
|
|
|
159
156
|
def compute_loss(
|
|
160
157
|
self,
|
|
161
158
|
out_map: tf.Tensor,
|
|
162
|
-
target:
|
|
159
|
+
target: list[dict[str, np.ndarray]],
|
|
163
160
|
eps: float = 1e-6,
|
|
164
161
|
) -> tf.Tensor:
|
|
165
162
|
"""Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
|
|
166
163
|
|
|
167
164
|
Args:
|
|
168
|
-
----
|
|
169
165
|
out_map: output feature map of the model of shape (N, num_classes, H, W)
|
|
170
166
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
171
167
|
eps: epsilon factor in dice loss
|
|
172
168
|
|
|
173
169
|
Returns:
|
|
174
|
-
-------
|
|
175
170
|
A loss tensor
|
|
176
171
|
"""
|
|
177
172
|
targets = self.build_target(target, out_map.shape[1:], True)
|
|
@@ -222,18 +217,18 @@ class FAST(_FAST, Model, NestedObject):
|
|
|
222
217
|
def call(
|
|
223
218
|
self,
|
|
224
219
|
x: tf.Tensor,
|
|
225
|
-
target:
|
|
220
|
+
target: list[dict[str, np.ndarray]] | None = None,
|
|
226
221
|
return_model_output: bool = False,
|
|
227
222
|
return_preds: bool = False,
|
|
228
223
|
**kwargs: Any,
|
|
229
|
-
) ->
|
|
224
|
+
) -> dict[str, Any]:
|
|
230
225
|
feat_maps = self.feat_extractor(x, **kwargs)
|
|
231
226
|
# Pass through the Neck & Head & Upsample
|
|
232
227
|
feat_concat = self.neck(feat_maps, **kwargs)
|
|
233
228
|
logits: tf.Tensor = self.head(feat_concat, **kwargs)
|
|
234
229
|
logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)
|
|
235
230
|
|
|
236
|
-
out:
|
|
231
|
+
out: dict[str, tf.Tensor] = {}
|
|
237
232
|
if self.exportable:
|
|
238
233
|
out["logits"] = logits
|
|
239
234
|
return out
|
|
@@ -255,15 +250,14 @@ class FAST(_FAST, Model, NestedObject):
|
|
|
255
250
|
return out
|
|
256
251
|
|
|
257
252
|
|
|
258
|
-
def reparameterize(model:
|
|
253
|
+
def reparameterize(model: FAST | layers.Layer) -> FAST:
|
|
259
254
|
"""Fuse batchnorm and conv layers and reparameterize the model
|
|
260
255
|
|
|
261
256
|
args:
|
|
262
|
-
|
|
257
|
+
|
|
263
258
|
model: the FAST model to reparameterize
|
|
264
259
|
|
|
265
260
|
Returns:
|
|
266
|
-
-------
|
|
267
261
|
the reparameterized model
|
|
268
262
|
"""
|
|
269
263
|
last_conv = None
|
|
@@ -306,9 +300,9 @@ def _fast(
|
|
|
306
300
|
arch: str,
|
|
307
301
|
pretrained: bool,
|
|
308
302
|
backbone_fn,
|
|
309
|
-
feat_layers:
|
|
303
|
+
feat_layers: list[str],
|
|
310
304
|
pretrained_backbone: bool = True,
|
|
311
|
-
input_shape:
|
|
305
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
312
306
|
**kwargs: Any,
|
|
313
307
|
) -> FAST:
|
|
314
308
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -358,12 +352,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
358
352
|
>>> out = model(input_tensor)
|
|
359
353
|
|
|
360
354
|
Args:
|
|
361
|
-
----
|
|
362
355
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
363
356
|
**kwargs: keyword arguments of the DBNet architecture
|
|
364
357
|
|
|
365
358
|
Returns:
|
|
366
|
-
-------
|
|
367
359
|
text detection architecture
|
|
368
360
|
"""
|
|
369
361
|
return _fast(
|
|
@@ -386,12 +378,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
386
378
|
>>> out = model(input_tensor)
|
|
387
379
|
|
|
388
380
|
Args:
|
|
389
|
-
----
|
|
390
381
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
391
382
|
**kwargs: keyword arguments of the DBNet architecture
|
|
392
383
|
|
|
393
384
|
Returns:
|
|
394
|
-
-------
|
|
395
385
|
text detection architecture
|
|
396
386
|
"""
|
|
397
387
|
return _fast(
|
|
@@ -414,12 +404,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
414
404
|
>>> out = model(input_tensor)
|
|
415
405
|
|
|
416
406
|
Args:
|
|
417
|
-
----
|
|
418
407
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
419
408
|
**kwargs: keyword arguments of the DBNet architecture
|
|
420
409
|
|
|
421
410
|
Returns:
|
|
422
|
-
-------
|
|
423
411
|
text detection architecture
|
|
424
412
|
"""
|
|
425
413
|
return _fast(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
elif
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|