python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
import torch
|
|
@@ -21,7 +22,7 @@ from .base import _FAST, FASTPostProcessor
|
|
|
21
22
|
__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
default_cfgs:
|
|
25
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
25
26
|
"fast_tiny": {
|
|
26
27
|
"input_shape": (3, 1024, 1024),
|
|
27
28
|
"mean": (0.798, 0.785, 0.772),
|
|
@@ -47,7 +48,6 @@ class FastNeck(nn.Module):
|
|
|
47
48
|
"""Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layers.
|
|
48
49
|
|
|
49
50
|
Args:
|
|
50
|
-
----
|
|
51
51
|
in_channels: number of input channels
|
|
52
52
|
out_channels: number of output channels
|
|
53
53
|
"""
|
|
@@ -77,7 +77,6 @@ class FastHead(nn.Sequential):
|
|
|
77
77
|
"""Head of the FAST architecture
|
|
78
78
|
|
|
79
79
|
Args:
|
|
80
|
-
----
|
|
81
80
|
in_channels: number of input channels
|
|
82
81
|
num_classes: number of output classes
|
|
83
82
|
out_channels: number of output channels
|
|
@@ -91,7 +90,7 @@ class FastHead(nn.Sequential):
|
|
|
91
90
|
out_channels: int = 128,
|
|
92
91
|
dropout: float = 0.1,
|
|
93
92
|
) -> None:
|
|
94
|
-
_layers:
|
|
93
|
+
_layers: list[nn.Module] = [
|
|
95
94
|
FASTConvLayer(in_channels, out_channels, kernel_size=3),
|
|
96
95
|
nn.Dropout(dropout),
|
|
97
96
|
nn.Conv2d(out_channels, num_classes, kernel_size=1, bias=False),
|
|
@@ -104,7 +103,6 @@ class FAST(_FAST, nn.Module):
|
|
|
104
103
|
<https://arxiv.org/pdf/2111.02394.pdf>`_.
|
|
105
104
|
|
|
106
105
|
Args:
|
|
107
|
-
----
|
|
108
106
|
feat extractor: the backbone serving as feature extractor
|
|
109
107
|
bin_thresh: threshold for binarization
|
|
110
108
|
box_thresh: minimal objectness score to consider a box
|
|
@@ -125,8 +123,8 @@ class FAST(_FAST, nn.Module):
|
|
|
125
123
|
pooling_size: int = 4, # different from paper performs better on close text-rich images
|
|
126
124
|
assume_straight_pages: bool = True,
|
|
127
125
|
exportable: bool = False,
|
|
128
|
-
cfg:
|
|
129
|
-
class_names:
|
|
126
|
+
cfg: dict[str, Any] = {},
|
|
127
|
+
class_names: list[str] = [CLASS_NAME],
|
|
130
128
|
) -> None:
|
|
131
129
|
super().__init__()
|
|
132
130
|
self.class_names = class_names
|
|
@@ -175,10 +173,10 @@ class FAST(_FAST, nn.Module):
|
|
|
175
173
|
def forward(
|
|
176
174
|
self,
|
|
177
175
|
x: torch.Tensor,
|
|
178
|
-
target:
|
|
176
|
+
target: list[np.ndarray] | None = None,
|
|
179
177
|
return_model_output: bool = False,
|
|
180
178
|
return_preds: bool = False,
|
|
181
|
-
) ->
|
|
179
|
+
) -> dict[str, torch.Tensor]:
|
|
182
180
|
# Extract feature maps at different stages
|
|
183
181
|
feats = self.feat_extractor(x)
|
|
184
182
|
feats = [feats[str(idx)] for idx in range(len(feats))]
|
|
@@ -186,7 +184,7 @@ class FAST(_FAST, nn.Module):
|
|
|
186
184
|
feat_concat = self.neck(feats)
|
|
187
185
|
logits = F.interpolate(self.prob_head(feat_concat), size=x.shape[-2:], mode="bilinear")
|
|
188
186
|
|
|
189
|
-
out:
|
|
187
|
+
out: dict[str, Any] = {}
|
|
190
188
|
if self.exportable:
|
|
191
189
|
out["logits"] = logits
|
|
192
190
|
return out
|
|
@@ -198,11 +196,16 @@ class FAST(_FAST, nn.Module):
|
|
|
198
196
|
out["out_map"] = prob_map
|
|
199
197
|
|
|
200
198
|
if target is None or return_preds:
|
|
199
|
+
# Disable for torch.compile compatibility
|
|
200
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
201
|
+
def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
|
|
202
|
+
return [
|
|
203
|
+
dict(zip(self.class_names, preds))
|
|
204
|
+
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
|
|
205
|
+
]
|
|
206
|
+
|
|
201
207
|
# Post-process boxes (keep only text predictions)
|
|
202
|
-
out["preds"] =
|
|
203
|
-
dict(zip(self.class_names, preds))
|
|
204
|
-
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
|
|
205
|
-
]
|
|
208
|
+
out["preds"] = _postprocess(prob_map)
|
|
206
209
|
|
|
207
210
|
if target is not None:
|
|
208
211
|
loss = self.compute_loss(logits, target)
|
|
@@ -213,19 +216,17 @@ class FAST(_FAST, nn.Module):
|
|
|
213
216
|
def compute_loss(
|
|
214
217
|
self,
|
|
215
218
|
out_map: torch.Tensor,
|
|
216
|
-
target:
|
|
219
|
+
target: list[np.ndarray],
|
|
217
220
|
eps: float = 1e-6,
|
|
218
221
|
) -> torch.Tensor:
|
|
219
222
|
"""Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
|
|
220
223
|
|
|
221
224
|
Args:
|
|
222
|
-
----
|
|
223
225
|
out_map: output feature map of the model of shape (N, num_classes, H, W)
|
|
224
226
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
225
227
|
eps: epsilon factor in dice loss
|
|
226
228
|
|
|
227
229
|
Returns:
|
|
228
|
-
-------
|
|
229
230
|
A loss tensor
|
|
230
231
|
"""
|
|
231
232
|
targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
|
|
@@ -279,15 +280,13 @@ class FAST(_FAST, nn.Module):
|
|
|
279
280
|
return text_loss + kernel_loss
|
|
280
281
|
|
|
281
282
|
|
|
282
|
-
def reparameterize(model:
|
|
283
|
+
def reparameterize(model: FAST | nn.Module) -> FAST:
|
|
283
284
|
"""Fuse batchnorm and conv layers and reparameterize the model
|
|
284
285
|
|
|
285
|
-
|
|
286
|
-
----
|
|
286
|
+
Args:
|
|
287
287
|
model: the FAST model to reparameterize
|
|
288
288
|
|
|
289
289
|
Returns:
|
|
290
|
-
-------
|
|
291
290
|
the reparameterized model
|
|
292
291
|
"""
|
|
293
292
|
last_conv = None
|
|
@@ -324,9 +323,9 @@ def _fast(
|
|
|
324
323
|
arch: str,
|
|
325
324
|
pretrained: bool,
|
|
326
325
|
backbone_fn: Callable[[bool], nn.Module],
|
|
327
|
-
feat_layers:
|
|
326
|
+
feat_layers: list[str],
|
|
328
327
|
pretrained_backbone: bool = True,
|
|
329
|
-
ignore_keys:
|
|
328
|
+
ignore_keys: list[str] | None = None,
|
|
330
329
|
**kwargs: Any,
|
|
331
330
|
) -> FAST:
|
|
332
331
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -366,12 +365,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
366
365
|
>>> out = model(input_tensor)
|
|
367
366
|
|
|
368
367
|
Args:
|
|
369
|
-
----
|
|
370
368
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
371
369
|
**kwargs: keyword arguments of the DBNet architecture
|
|
372
370
|
|
|
373
371
|
Returns:
|
|
374
|
-
-------
|
|
375
372
|
text detection architecture
|
|
376
373
|
"""
|
|
377
374
|
return _fast(
|
|
@@ -395,12 +392,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
395
392
|
>>> out = model(input_tensor)
|
|
396
393
|
|
|
397
394
|
Args:
|
|
398
|
-
----
|
|
399
395
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
400
396
|
**kwargs: keyword arguments of the DBNet architecture
|
|
401
397
|
|
|
402
398
|
Returns:
|
|
403
|
-
-------
|
|
404
399
|
text detection architecture
|
|
405
400
|
"""
|
|
406
401
|
return _fast(
|
|
@@ -424,12 +419,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
424
419
|
>>> out = model(input_tensor)
|
|
425
420
|
|
|
426
421
|
Args:
|
|
427
|
-
----
|
|
428
422
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
429
423
|
**kwargs: keyword arguments of the DBNet architecture
|
|
430
424
|
|
|
431
425
|
Returns:
|
|
432
|
-
-------
|
|
433
426
|
text detection architecture
|
|
434
427
|
"""
|
|
435
428
|
return _fast(
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -6,15 +6,14 @@
|
|
|
6
6
|
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
7
|
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import tensorflow as tf
|
|
13
|
-
from tensorflow import
|
|
14
|
-
from tensorflow.keras import Sequential, layers
|
|
13
|
+
from tensorflow.keras import Model, Sequential, layers
|
|
15
14
|
|
|
16
15
|
from doctr.file_utils import CLASS_NAME
|
|
17
|
-
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params
|
|
16
|
+
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params
|
|
18
17
|
from doctr.utils.repr import NestedObject
|
|
19
18
|
|
|
20
19
|
from ...classification import textnet_base, textnet_small, textnet_tiny
|
|
@@ -24,24 +23,24 @@ from .base import _FAST, FASTPostProcessor
|
|
|
24
23
|
__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
|
|
25
24
|
|
|
26
25
|
|
|
27
|
-
default_cfgs:
|
|
26
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
28
27
|
"fast_tiny": {
|
|
29
28
|
"input_shape": (1024, 1024, 3),
|
|
30
29
|
"mean": (0.798, 0.785, 0.772),
|
|
31
30
|
"std": (0.264, 0.2749, 0.287),
|
|
32
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
31
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0",
|
|
33
32
|
},
|
|
34
33
|
"fast_small": {
|
|
35
34
|
"input_shape": (1024, 1024, 3),
|
|
36
35
|
"mean": (0.798, 0.785, 0.772),
|
|
37
36
|
"std": (0.264, 0.2749, 0.287),
|
|
38
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
37
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0",
|
|
39
38
|
},
|
|
40
39
|
"fast_base": {
|
|
41
40
|
"input_shape": (1024, 1024, 3),
|
|
42
41
|
"mean": (0.798, 0.785, 0.772),
|
|
43
42
|
"std": (0.264, 0.2749, 0.287),
|
|
44
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
43
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0",
|
|
45
44
|
},
|
|
46
45
|
}
|
|
47
46
|
|
|
@@ -50,7 +49,6 @@ class FastNeck(layers.Layer, NestedObject):
|
|
|
50
49
|
"""Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer.
|
|
51
50
|
|
|
52
51
|
Args:
|
|
53
|
-
----
|
|
54
52
|
in_channels: number of input channels
|
|
55
53
|
out_channels: number of output channels
|
|
56
54
|
"""
|
|
@@ -78,7 +76,6 @@ class FastHead(Sequential):
|
|
|
78
76
|
"""Head of the FAST architecture
|
|
79
77
|
|
|
80
78
|
Args:
|
|
81
|
-
----
|
|
82
79
|
in_channels: number of input channels
|
|
83
80
|
num_classes: number of output classes
|
|
84
81
|
out_channels: number of output channels
|
|
@@ -100,12 +97,11 @@ class FastHead(Sequential):
|
|
|
100
97
|
super().__init__(_layers)
|
|
101
98
|
|
|
102
99
|
|
|
103
|
-
class FAST(_FAST,
|
|
100
|
+
class FAST(_FAST, Model, NestedObject):
|
|
104
101
|
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
|
|
105
102
|
<https://arxiv.org/pdf/2111.02394.pdf>`_.
|
|
106
103
|
|
|
107
104
|
Args:
|
|
108
|
-
----
|
|
109
105
|
feature extractor: the backbone serving as feature extractor
|
|
110
106
|
bin_thresh: threshold for binarization
|
|
111
107
|
box_thresh: minimal objectness score to consider a box
|
|
@@ -117,7 +113,7 @@ class FAST(_FAST, keras.Model, NestedObject):
|
|
|
117
113
|
class_names: list of class names
|
|
118
114
|
"""
|
|
119
115
|
|
|
120
|
-
_children_names:
|
|
116
|
+
_children_names: list[str] = ["feat_extractor", "neck", "head", "postprocessor"]
|
|
121
117
|
|
|
122
118
|
def __init__(
|
|
123
119
|
self,
|
|
@@ -128,8 +124,8 @@ class FAST(_FAST, keras.Model, NestedObject):
|
|
|
128
124
|
pooling_size: int = 4, # different from paper performs better on close text-rich images
|
|
129
125
|
assume_straight_pages: bool = True,
|
|
130
126
|
exportable: bool = False,
|
|
131
|
-
cfg:
|
|
132
|
-
class_names:
|
|
127
|
+
cfg: dict[str, Any] = {},
|
|
128
|
+
class_names: list[str] = [CLASS_NAME],
|
|
133
129
|
) -> None:
|
|
134
130
|
super().__init__()
|
|
135
131
|
self.class_names = class_names
|
|
@@ -160,19 +156,17 @@ class FAST(_FAST, keras.Model, NestedObject):
|
|
|
160
156
|
def compute_loss(
|
|
161
157
|
self,
|
|
162
158
|
out_map: tf.Tensor,
|
|
163
|
-
target:
|
|
159
|
+
target: list[dict[str, np.ndarray]],
|
|
164
160
|
eps: float = 1e-6,
|
|
165
161
|
) -> tf.Tensor:
|
|
166
162
|
"""Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
|
|
167
163
|
|
|
168
164
|
Args:
|
|
169
|
-
----
|
|
170
165
|
out_map: output feature map of the model of shape (N, num_classes, H, W)
|
|
171
166
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
172
167
|
eps: epsilon factor in dice loss
|
|
173
168
|
|
|
174
169
|
Returns:
|
|
175
|
-
-------
|
|
176
170
|
A loss tensor
|
|
177
171
|
"""
|
|
178
172
|
targets = self.build_target(target, out_map.shape[1:], True)
|
|
@@ -223,18 +217,18 @@ class FAST(_FAST, keras.Model, NestedObject):
|
|
|
223
217
|
def call(
|
|
224
218
|
self,
|
|
225
219
|
x: tf.Tensor,
|
|
226
|
-
target:
|
|
220
|
+
target: list[dict[str, np.ndarray]] | None = None,
|
|
227
221
|
return_model_output: bool = False,
|
|
228
222
|
return_preds: bool = False,
|
|
229
223
|
**kwargs: Any,
|
|
230
|
-
) ->
|
|
224
|
+
) -> dict[str, Any]:
|
|
231
225
|
feat_maps = self.feat_extractor(x, **kwargs)
|
|
232
226
|
# Pass through the Neck & Head & Upsample
|
|
233
227
|
feat_concat = self.neck(feat_maps, **kwargs)
|
|
234
228
|
logits: tf.Tensor = self.head(feat_concat, **kwargs)
|
|
235
229
|
logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)
|
|
236
230
|
|
|
237
|
-
out:
|
|
231
|
+
out: dict[str, tf.Tensor] = {}
|
|
238
232
|
if self.exportable:
|
|
239
233
|
out["logits"] = logits
|
|
240
234
|
return out
|
|
@@ -256,15 +250,14 @@ class FAST(_FAST, keras.Model, NestedObject):
|
|
|
256
250
|
return out
|
|
257
251
|
|
|
258
252
|
|
|
259
|
-
def reparameterize(model:
|
|
253
|
+
def reparameterize(model: FAST | layers.Layer) -> FAST:
|
|
260
254
|
"""Fuse batchnorm and conv layers and reparameterize the model
|
|
261
255
|
|
|
262
256
|
args:
|
|
263
|
-
|
|
257
|
+
|
|
264
258
|
model: the FAST model to reparameterize
|
|
265
259
|
|
|
266
260
|
Returns:
|
|
267
|
-
-------
|
|
268
261
|
the reparameterized model
|
|
269
262
|
"""
|
|
270
263
|
last_conv = None
|
|
@@ -307,9 +300,9 @@ def _fast(
|
|
|
307
300
|
arch: str,
|
|
308
301
|
pretrained: bool,
|
|
309
302
|
backbone_fn,
|
|
310
|
-
feat_layers:
|
|
303
|
+
feat_layers: list[str],
|
|
311
304
|
pretrained_backbone: bool = True,
|
|
312
|
-
input_shape:
|
|
305
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
313
306
|
**kwargs: Any,
|
|
314
307
|
) -> FAST:
|
|
315
308
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -334,12 +327,16 @@ def _fast(
|
|
|
334
327
|
|
|
335
328
|
# Build the model
|
|
336
329
|
model = FAST(feat_extractor, cfg=_cfg, **kwargs)
|
|
330
|
+
_build_model(model)
|
|
331
|
+
|
|
337
332
|
# Load pretrained parameters
|
|
338
333
|
if pretrained:
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
334
|
+
# The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
|
|
335
|
+
load_pretrained_params(
|
|
336
|
+
model,
|
|
337
|
+
_cfg["url"],
|
|
338
|
+
skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
|
|
339
|
+
)
|
|
343
340
|
|
|
344
341
|
return model
|
|
345
342
|
|
|
@@ -355,12 +352,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
355
352
|
>>> out = model(input_tensor)
|
|
356
353
|
|
|
357
354
|
Args:
|
|
358
|
-
----
|
|
359
355
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
360
356
|
**kwargs: keyword arguments of the DBNet architecture
|
|
361
357
|
|
|
362
358
|
Returns:
|
|
363
|
-
-------
|
|
364
359
|
text detection architecture
|
|
365
360
|
"""
|
|
366
361
|
return _fast(
|
|
@@ -383,12 +378,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
383
378
|
>>> out = model(input_tensor)
|
|
384
379
|
|
|
385
380
|
Args:
|
|
386
|
-
----
|
|
387
381
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
388
382
|
**kwargs: keyword arguments of the DBNet architecture
|
|
389
383
|
|
|
390
384
|
Returns:
|
|
391
|
-
-------
|
|
392
385
|
text detection architecture
|
|
393
386
|
"""
|
|
394
387
|
return _fast(
|
|
@@ -411,12 +404,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
|
|
|
411
404
|
>>> out = model(input_tensor)
|
|
412
405
|
|
|
413
406
|
Args:
|
|
414
|
-
----
|
|
415
407
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
416
408
|
**kwargs: keyword arguments of the DBNet architecture
|
|
417
409
|
|
|
418
410
|
Returns:
|
|
419
|
-
-------
|
|
420
411
|
text detection architecture
|
|
421
412
|
"""
|
|
422
413
|
return _fast(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
elif
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,11 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
7
|
|
|
8
|
-
from typing import Dict, List, Tuple, Union
|
|
9
8
|
|
|
10
9
|
import cv2
|
|
11
10
|
import numpy as np
|
|
@@ -23,7 +22,6 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
23
22
|
"""Implements a post processor for LinkNet model.
|
|
24
23
|
|
|
25
24
|
Args:
|
|
26
|
-
----
|
|
27
25
|
bin_thresh: threshold used to binzarized p_map at inference time
|
|
28
26
|
box_thresh: minimal objectness score to consider a box
|
|
29
27
|
assume_straight_pages: whether the inputs were expected to have horizontal text elements
|
|
@@ -45,11 +43,9 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
45
43
|
"""Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
|
|
46
44
|
|
|
47
45
|
Args:
|
|
48
|
-
----
|
|
49
46
|
points: The first parameter.
|
|
50
47
|
|
|
51
48
|
Returns:
|
|
52
|
-
-------
|
|
53
49
|
a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
|
|
54
50
|
"""
|
|
55
51
|
if not self.assume_straight_pages:
|
|
@@ -94,24 +90,22 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
94
90
|
"""Compute boxes from a bitmap/pred_map: find connected components then filter boxes
|
|
95
91
|
|
|
96
92
|
Args:
|
|
97
|
-
----
|
|
98
93
|
pred: Pred map from differentiable linknet output
|
|
99
94
|
bitmap: Bitmap map computed from pred (binarized)
|
|
100
95
|
angle_tol: Comparison tolerance of the angle with the median angle across the page
|
|
101
96
|
ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
|
|
102
97
|
|
|
103
98
|
Returns:
|
|
104
|
-
-------
|
|
105
99
|
np tensor boxes for the bitmap, each box is a 6-element list
|
|
106
100
|
containing x, y, w, h, alpha, score for the box
|
|
107
101
|
"""
|
|
108
102
|
height, width = bitmap.shape[:2]
|
|
109
|
-
boxes:
|
|
103
|
+
boxes: list[np.ndarray | list[float]] = []
|
|
110
104
|
# get contours from connected components on the bitmap
|
|
111
105
|
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
112
106
|
for contour in contours:
|
|
113
107
|
# Check whether smallest enclosing bounding box is not too small
|
|
114
|
-
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
|
|
108
|
+
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
|
|
115
109
|
continue
|
|
116
110
|
# Compute objectness
|
|
117
111
|
if self.assume_straight_pages:
|
|
@@ -152,7 +146,6 @@ class _LinkNet(BaseModel):
|
|
|
152
146
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
153
147
|
|
|
154
148
|
Args:
|
|
155
|
-
----
|
|
156
149
|
out_chan: number of channels for the output
|
|
157
150
|
"""
|
|
158
151
|
|
|
@@ -162,20 +155,18 @@ class _LinkNet(BaseModel):
|
|
|
162
155
|
|
|
163
156
|
def build_target(
|
|
164
157
|
self,
|
|
165
|
-
target:
|
|
166
|
-
output_shape:
|
|
158
|
+
target: list[dict[str, np.ndarray]],
|
|
159
|
+
output_shape: tuple[int, int, int],
|
|
167
160
|
channels_last: bool = True,
|
|
168
|
-
) ->
|
|
161
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
169
162
|
"""Build the target, and it's mask to be used from loss computation.
|
|
170
163
|
|
|
171
164
|
Args:
|
|
172
|
-
----
|
|
173
165
|
target: target coming from dataset
|
|
174
166
|
output_shape: shape of the output of the model without batch_size
|
|
175
167
|
channels_last: whether channels are last or not
|
|
176
168
|
|
|
177
169
|
Returns:
|
|
178
|
-
-------
|
|
179
170
|
the new formatted target and the mask
|
|
180
171
|
"""
|
|
181
172
|
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
|