python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/detection/differentiable_binarization/pytorch.py

```diff
@@ -1,9 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, Callable, Dict, List, Optional
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 import torch
```
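The import swap above is part of a package-wide typing migration visible throughout this diff: `typing.Dict`, `List`, `Optional` and friends give way to PEP 585 builtin generics and PEP 604 unions, and `Callable` now comes from `collections.abc`. A minimal sketch of the before/after style (the function and its parameters are illustrative, not doctr API):

```python
# Illustrative only: the annotation style 0.11.0 migrates to (PEP 585/604).
from collections.abc import Callable
from typing import Any


def build(
    cfg: dict[str, Any] | None = None,  # 0.9.0 style: Optional[Dict[str, Any]]
    fpn_layers: list[str] | None = None,  # 0.9.0 style: Optional[List[str]]
    backbone_fn: Callable[[bool], Any] | None = None,  # Callable from collections.abc
) -> dict[str, Any]:
    # Echo the resolved configuration
    return {"cfg": cfg or {}, "fpn_layers": fpn_layers or [], "backbone_fn": backbone_fn}
```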
```diff
@@ -22,7 +23,7 @@ from .base import DBPostProcessor, _DBNet
 __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "db_resnet50": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
@@ -47,7 +48,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 class FeaturePyramidNetwork(nn.Module):
     def __init__(
         self,
-        in_channels: List[int],
+        in_channels: list[int],
         out_channels: int,
         deform_conv: bool = False,
     ) -> None:
@@ -76,12 +77,12 @@ class FeaturePyramidNetwork(nn.Module):
             for idx, chans in enumerate(in_channels)
         ])
 
-    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+    def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
         if len(x) != len(self.out_branches):
             raise AssertionError
         # Conv1x1 to get the same number of channels
-        _x: List[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
-        out: List[torch.Tensor] = [_x[-1]]
+        _x: list[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
+        out: list[torch.Tensor] = [_x[-1]]
         for t in _x[:-1][::-1]:
             out.append(self.upsample(out[-1]) + t)
 
@@ -96,7 +97,6 @@ class DBNet(_DBNet, nn.Module):
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         head_chans: the number of channels in the head
         deform_conv: whether to use deformable convolution
@@ -117,8 +117,8 @@ class DBNet(_DBNet, nn.Module):
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
-        class_names: List[str] = [CLASS_NAME],
+        cfg: dict[str, Any] | None = None,
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -182,10 +182,10 @@ class DBNet(_DBNet, nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        target: Optional[List[np.ndarray]] = None,
+        target: list[np.ndarray] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, torch.Tensor]:
+    ) -> dict[str, torch.Tensor]:
         # Extract feature maps at different stages
         feats = self.feat_extractor(x)
         feats = [feats[str(idx)] for idx in range(len(feats))]
@@ -193,7 +193,7 @@ class DBNet(_DBNet, nn.Module):
         feat_concat = self.fpn(feats)
         logits = self.prob_head(feat_concat)
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
             return out
```
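Taken together, the two hunks above fix the dict-based output contract of `DBNet.forward`. A hedged usage sketch (PyTorch backend; the keys follow the hunks shown here, with `preds` produced by the post-processing change in the next hunk):

```python
# Hedged sketch: exercising the forward contract of the PyTorch DBNet.
import torch
from doctr.models import db_resnet50

model = db_resnet50(pretrained=False).eval()
dummy = torch.rand(2, 3, 1024, 1024)  # matches default_cfgs "input_shape"

with torch.no_grad():
    out = model(dummy, return_model_output=True, return_preds=True)

# "out_map" is the sigmoid probability map; "preds" holds, per sample, a dict
# mapping each entry of model.class_names to its post-processed detections.
print(out["out_map"].shape, list(out["preds"][0].keys()))
```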
```diff
@@ -205,11 +205,16 @@ class DBNet(_DBNet, nn.Module):
         out["out_map"] = prob_map
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
+                return [
+                    dict(zip(self.class_names, preds))
+                    for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                ]
+
             # Post-process boxes (keep only text predictions)
-            out["preds"] = [
-                dict(zip(self.class_names, preds))
-                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
-            ]
+            out["preds"] = _postprocess(prob_map)
 
         if target is not None:
             thresh_map = self.thresh_head(feat_concat)
```
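Moving the NumPy-based post-processing into a locally defined function decorated with `@torch.compiler.disable` keeps the tensor path of `forward` traceable by `torch.compile` while the box extraction stays eager. A minimal, self-contained sketch of the same pattern (all names here are illustrative):

```python
# Sketch of the torch.compile pattern used above: compile the tensor graph,
# fence off the NumPy post-processing so the tracer never sees it.
import torch


@torch.compiler.disable  # post-processing leaves tensor land, so skip tracing it
def postprocess(prob_map: torch.Tensor) -> list[float]:
    return prob_map.detach().cpu().numpy().mean(axis=(1, 2, 3)).tolist()


@torch.compile
def model_step(x: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(x * 2.0)


out = model_step(torch.rand(2, 1, 8, 8))
scores = postprocess(out)  # runs eagerly, outside the compiled graph
```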
```diff
@@ -222,7 +227,7 @@ class DBNet(_DBNet, nn.Module):
         self,
         out_map: torch.Tensor,
         thresh_map: torch.Tensor,
-        target: List[np.ndarray],
+        target: list[np.ndarray],
         gamma: float = 2.0,
         alpha: float = 0.5,
         eps: float = 1e-8,
@@ -231,7 +236,6 @@ class DBNet(_DBNet, nn.Module):
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
-        ----
             out_map: output feature map of the model of shape (N, C, H, W)
             thresh_map: threshold map of shape (N, C, H, W)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
@@ -240,7 +244,6 @@ class DBNet(_DBNet, nn.Module):
             eps: epsilon factor in dice loss
 
         Returns:
-        -------
             A loss tensor
         """
         if gamma < 0:
@@ -273,7 +276,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
```
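The `dice_map` line above is the approximate binarization step from the DB paper (Liao et al., 2019): B̂ = 1 / (1 + e^(-k(P - T))) with amplification factor k = 50, where P is the probability map and T the learned threshold map. The same computation with an explicit sigmoid (a standalone sketch, not doctr code):

```python
# The soft binarization from the DB paper, as used in the dice_map line above:
# B_hat = sigmoid(k * (P - T)) with k = 50.
import torch

k = 50.0
prob_map = torch.rand(1, 1, 4, 4)    # P: probability map
thresh_map = torch.rand(1, 1, 4, 4)  # T: learned threshold map
binary_map = torch.sigmoid(k * (prob_map - thresh_map))  # equals 1/(1+exp(-k(P-T)))
```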
```diff
@@ -290,10 +293,10 @@ def _dbnet(
     arch: str,
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
-    fpn_layers: List[str],
-    backbone_submodule: Optional[str] = None,
+    fpn_layers: list[str],
+    backbone_submodule: str | None = None,
     pretrained_backbone: bool = True,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -341,12 +344,10 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
@@ -376,12 +377,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
@@ -411,12 +410,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
```
doctr/models/detection/differentiable_binarization/tensorflow.py

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,16 +6,21 @@
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
+from tensorflow.keras import Model, Sequential, layers, losses
 from tensorflow.keras.applications import ResNet50
 
 from doctr.file_utils import CLASS_NAME
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
+from doctr.models.utils import (
+    IntermediateLayerGetter,
+    _bf16_to_float32,
+    _build_model,
+    conv_sequence,
+    load_pretrained_params,
+)
 from doctr.utils.repr import NestedObject
 
 from ...classification import mobilenet_v3_large
@@ -24,18 +29,18 @@ from .base import DBPostProcessor, _DBNet
 __all__ = ["DBNet", "db_resnet50", "db_mobilenet_v3_large"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "db_resnet50": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0",
     },
     "db_mobilenet_v3_large": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0",
     },
 }
 
@@ -45,7 +50,6 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
     <https://arxiv.org/pdf/1612.03144.pdf>`_.
 
     Args:
-    ----
         channels: number of channel to output
     """
 
@@ -67,12 +71,10 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         """Module which performs a 3x3 convolution followed by up-sampling
 
         Args:
-        ----
             channels: number of output channels
             dilation_factor (int): dilation factor to scale the convolution output before concatenation
 
         Returns:
-        -------
             a keras.layers.Layer object, wrapping these operations in a sequential module
 
         """
@@ -81,7 +83,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         if dilation_factor > 1:
             _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest"))
 
-        module = keras.Sequential(_layers)
+        module = Sequential(_layers)
 
         return module
 
@@ -90,7 +92,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
 
     def call(
         self,
-        x: List[tf.Tensor],
+        x: list[tf.Tensor],
         **kwargs: Any,
     ) -> tf.Tensor:
         # Channel mapping
@@ -104,12 +106,11 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         return layers.concatenate(results)
 
 
-class DBNet(_DBNet, keras.Model, NestedObject):
+class DBNet(_DBNet, Model, NestedObject):
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
         bin_thresh: threshold for binarization
@@ -120,7 +121,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         class_names: list of class names
     """
 
-    _children_names: List[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
+    _children_names: list[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
 
     def __init__(
         self,
@@ -130,8 +131,8 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
-        class_names: List[str] = [CLASS_NAME],
+        cfg: dict[str, Any] | None = None,
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -147,14 +148,14 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
         output_shape = tuple(self.fpn(_inputs).shape)
 
-        self.probability_head = keras.Sequential([
+        self.probability_head = Sequential([
             *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
             layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
             layers.BatchNormalization(),
             layers.Activation("relu"),
             layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
         ])
-        self.threshold_head = keras.Sequential([
+        self.threshold_head = Sequential([
             *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
             layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
             layers.BatchNormalization(),
@@ -170,7 +171,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         self,
         out_map: tf.Tensor,
         thresh_map: tf.Tensor,
-        target: List[Dict[str, np.ndarray]],
+        target: list[dict[str, np.ndarray]],
         gamma: float = 2.0,
         alpha: float = 0.5,
         eps: float = 1e-8,
@@ -179,7 +180,6 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
-        ----
             out_map: output feature map of the model of shape (N, H, W, C)
             thresh_map: threshold map of shape (N, H, W, C)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
@@ -188,7 +188,6 @@ class DBNet(_DBNet, keras.Model, NestedObject):
             eps: epsilon factor in dice loss
 
         Returns:
-        -------
             A loss tensor
         """
         if gamma < 0:
@@ -206,7 +205,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
 
         # Focal loss
         focal_scale = 10.0
-        bce_loss = keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
+        bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
 
         # Convert logits to prob, compute gamma factor
         p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
@@ -241,16 +240,16 @@ class DBNet(_DBNet, keras.Model, NestedObject):
     def call(
         self,
         x: tf.Tensor,
-        target: Optional[List[Dict[str, np.ndarray]]] = None,
+        target: list[dict[str, np.ndarray]] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         feat_maps = self.feat_extractor(x, **kwargs)
         feat_concat = self.fpn(feat_maps, **kwargs)
         logits = self.probability_head(feat_concat, **kwargs)
 
-        out: Dict[str, tf.Tensor] = {}
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -277,9 +276,9 @@ def _db_resnet(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    fpn_layers: List[str],
+    fpn_layers: list[str],
     pretrained_backbone: bool = True,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -305,9 +304,16 @@ def _db_resnet(
 
     # Build the model
     model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, _cfg["url"])
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
 
@@ -316,9 +322,9 @@ def _db_mobilenet(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    fpn_layers: List[str],
+    fpn_layers: list[str],
     pretrained_backbone: bool = True,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -326,6 +332,10 @@ def _db_mobilenet(
     # Patch the config
     _cfg = deepcopy(default_cfgs[arch])
     _cfg["input_shape"] = input_shape or _cfg["input_shape"]
+    if not kwargs.get("class_names", None):
+        kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
+    else:
+        kwargs["class_names"] = sorted(kwargs["class_names"])
 
     # Feature extractor
     feat_extractor = IntermediateLayerGetter(
@@ -339,9 +349,15 @@ def _db_mobilenet(
 
     # Build the model
     model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, _cfg["url"])
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
 
@@ -357,12 +373,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _db_resnet(
@@ -385,12 +399,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _db_mobilenet(
```
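Two behavioral changes surface in both TF builder hunks: `_build_model` now forces variable creation before weights are loaded (the pretrained checkpoints moved to the `.weights.h5` format, which requires a built model), and `skip_mismatch` lets a pretrained backbone load under a custom class set. A hedged fine-tuning sketch (assumes the TF backend and that `class_names` is forwarded through `**kwargs` as in the hunks above; downloads weights on first run):

```python
# Hedged sketch: fine-tuning DBNet with custom detection classes (TF backend).
from doctr.models import db_resnet50

# class_names differs from the pretrained default ([CLASS_NAME]), so the
# builder passes skip_mismatch=True and the mismatching head weights are
# simply left at their fresh initialization.
model = db_resnet50(pretrained=True, class_names=["paragraph", "title"])
print(model.class_names)
```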
doctr/models/detection/fast/__init__.py

```diff
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
```
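This hunk (by position in the file list, most plausibly `doctr/models/detection/fast/__init__.py`, which the list shows as +4 -4) flips the framework priority: with both TensorFlow and PyTorch installed, the package now resolves to the PyTorch implementation first. Downstream imports are unchanged; a quick check (sketch):

```python
# Sketch: the wildcard re-export keeps user-facing imports stable; only the
# chosen backend changes when both frameworks are installed.
from doctr.file_utils import is_tf_available, is_torch_available
from doctr.models.detection import fast_base

print("torch available:", is_torch_available(), "| tf available:", is_tf_available())
model = fast_base(pretrained=False)
print(type(model).__module__)  # ends in fast.pytorch when torch is installed
```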
doctr/models/detection/fast/base.py

```diff
@@ -1,11 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -23,7 +22,6 @@ class FASTPostProcessor(DetectionPostProcessor):
     """Implements a post processor for FAST model.
 
     Args:
-    ----
         bin_thresh: threshold used to binzarized p_map at inference time
         box_thresh: minimal objectness score to consider a box
         assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -45,11 +43,9 @@ class FASTPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
-        ----
             points: The first parameter.
 
         Returns:
-        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -94,24 +90,22 @@ class FASTPostProcessor(DetectionPostProcessor):
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
-        ----
             pred: Pred map from differentiable linknet output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
-        -------
             np tensor boxes for the bitmap, each box is a 6-element list
             containing x, y, w, h, alpha, score for the box
         """
         height, width = bitmap.shape[:2]
-        boxes: List[Union[np.ndarray, List[float]]] = []
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -158,20 +152,18 @@ class _FAST(BaseModel):
 
     def build_target(
         self,
-        target: List[Dict[str, np.ndarray]],
-        output_shape: Tuple[int, int, int],
+        target: list[dict[str, np.ndarray]],
+        output_shape: tuple[int, int, int],
         channels_last: bool = True,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
-        ----
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
             channels_last: whether channels are last or not
 
         Returns:
-        -------
             the new formatted target, mask and shrunken text kernel
         """
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
```