python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
import torch
|
|
@@ -20,7 +21,7 @@ from .base import LinkNetPostProcessor, _LinkNet
|
|
|
20
21
|
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
|
|
21
22
|
|
|
22
23
|
|
|
23
|
-
default_cfgs:
|
|
24
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
24
25
|
"linknet_resnet18": {
|
|
25
26
|
"input_shape": (3, 1024, 1024),
|
|
26
27
|
"mean": (0.798, 0.785, 0.772),
|
|
@@ -43,7 +44,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
43
44
|
|
|
44
45
|
|
|
45
46
|
class LinkNetFPN(nn.Module):
|
|
46
|
-
def __init__(self, layer_shapes:
|
|
47
|
+
def __init__(self, layer_shapes: list[tuple[int, int, int]]) -> None:
|
|
47
48
|
super().__init__()
|
|
48
49
|
strides = [
|
|
49
50
|
1 if (in_shape[-1] == out_shape[-1]) else 2
|
|
@@ -74,7 +75,7 @@ class LinkNetFPN(nn.Module):
|
|
|
74
75
|
nn.ReLU(inplace=True),
|
|
75
76
|
)
|
|
76
77
|
|
|
77
|
-
def forward(self, feats:
|
|
78
|
+
def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
|
|
78
79
|
out = feats[-1]
|
|
79
80
|
for decoder, fmap in zip(self.decoders[::-1], feats[:-1][::-1]):
|
|
80
81
|
out = decoder(out) + fmap
|
|
@@ -89,7 +90,6 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
89
90
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
90
91
|
|
|
91
92
|
Args:
|
|
92
|
-
----
|
|
93
93
|
feature extractor: the backbone serving as feature extractor
|
|
94
94
|
bin_thresh: threshold for binarization of the output feature map
|
|
95
95
|
box_thresh: minimal objectness score to consider a box
|
|
@@ -108,8 +108,8 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
108
108
|
head_chans: int = 32,
|
|
109
109
|
assume_straight_pages: bool = True,
|
|
110
110
|
exportable: bool = False,
|
|
111
|
-
cfg:
|
|
112
|
-
class_names:
|
|
111
|
+
cfg: dict[str, Any] | None = None,
|
|
112
|
+
class_names: list[str] = [CLASS_NAME],
|
|
113
113
|
) -> None:
|
|
114
114
|
super().__init__()
|
|
115
115
|
self.class_names = class_names
|
|
@@ -163,16 +163,16 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
163
163
|
def forward(
|
|
164
164
|
self,
|
|
165
165
|
x: torch.Tensor,
|
|
166
|
-
target:
|
|
166
|
+
target: list[np.ndarray] | None = None,
|
|
167
167
|
return_model_output: bool = False,
|
|
168
168
|
return_preds: bool = False,
|
|
169
169
|
**kwargs: Any,
|
|
170
|
-
) ->
|
|
170
|
+
) -> dict[str, Any]:
|
|
171
171
|
feats = self.feat_extractor(x)
|
|
172
172
|
logits = self.fpn([feats[str(idx)] for idx in range(len(feats))])
|
|
173
173
|
logits = self.classifier(logits)
|
|
174
174
|
|
|
175
|
-
out:
|
|
175
|
+
out: dict[str, Any] = {}
|
|
176
176
|
if self.exportable:
|
|
177
177
|
out["logits"] = logits
|
|
178
178
|
return out
|
|
@@ -183,11 +183,16 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
183
183
|
out["out_map"] = prob_map
|
|
184
184
|
|
|
185
185
|
if target is None or return_preds:
|
|
186
|
-
#
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
186
|
+
# Disable for torch.compile compatibility
|
|
187
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
188
|
+
def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
|
|
189
|
+
return [
|
|
190
|
+
dict(zip(self.class_names, preds))
|
|
191
|
+
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
|
|
192
|
+
]
|
|
193
|
+
|
|
194
|
+
# Post-process boxes (keep only text predictions)
|
|
195
|
+
out["preds"] = _postprocess(prob_map)
|
|
191
196
|
|
|
192
197
|
if target is not None:
|
|
193
198
|
loss = self.compute_loss(logits, target)
|
|
@@ -198,7 +203,7 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
198
203
|
def compute_loss(
|
|
199
204
|
self,
|
|
200
205
|
out_map: torch.Tensor,
|
|
201
|
-
target:
|
|
206
|
+
target: list[np.ndarray],
|
|
202
207
|
gamma: float = 2.0,
|
|
203
208
|
alpha: float = 0.5,
|
|
204
209
|
eps: float = 1e-8,
|
|
@@ -207,7 +212,6 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
207
212
|
<https://github.com/tensorflow/addons/>`_.
|
|
208
213
|
|
|
209
214
|
Args:
|
|
210
|
-
----
|
|
211
215
|
out_map: output feature map of the model of shape (N, num_classes, H, W)
|
|
212
216
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
213
217
|
gamma: modulating factor in the focal loss formula
|
|
@@ -215,7 +219,6 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
215
219
|
eps: epsilon factor in dice loss
|
|
216
220
|
|
|
217
221
|
Returns:
|
|
218
|
-
-------
|
|
219
222
|
A loss tensor
|
|
220
223
|
"""
|
|
221
224
|
_target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
|
|
@@ -252,9 +255,9 @@ def _linknet(
|
|
|
252
255
|
arch: str,
|
|
253
256
|
pretrained: bool,
|
|
254
257
|
backbone_fn: Callable[[bool], nn.Module],
|
|
255
|
-
fpn_layers:
|
|
258
|
+
fpn_layers: list[str],
|
|
256
259
|
pretrained_backbone: bool = True,
|
|
257
|
-
ignore_keys:
|
|
260
|
+
ignore_keys: list[str] | None = None,
|
|
258
261
|
**kwargs: Any,
|
|
259
262
|
) -> LinkNet:
|
|
260
263
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -295,12 +298,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
295
298
|
>>> out = model(input_tensor)
|
|
296
299
|
|
|
297
300
|
Args:
|
|
298
|
-
----
|
|
299
301
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
300
302
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
301
303
|
|
|
302
304
|
Returns:
|
|
303
|
-
-------
|
|
304
305
|
text detection architecture
|
|
305
306
|
"""
|
|
306
307
|
return _linknet(
|
|
@@ -327,12 +328,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
327
328
|
>>> out = model(input_tensor)
|
|
328
329
|
|
|
329
330
|
Args:
|
|
330
|
-
----
|
|
331
331
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
332
332
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
333
333
|
|
|
334
334
|
Returns:
|
|
335
|
-
-------
|
|
336
335
|
text detection architecture
|
|
337
336
|
"""
|
|
338
337
|
return _linknet(
|
|
@@ -359,12 +358,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
359
358
|
>>> out = model(input_tensor)
|
|
360
359
|
|
|
361
360
|
Args:
|
|
362
|
-
----
|
|
363
361
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
364
362
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
365
363
|
|
|
366
364
|
Returns:
|
|
367
|
-
-------
|
|
368
365
|
text detection architecture
|
|
369
366
|
"""
|
|
370
367
|
return _linknet(
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -6,40 +6,45 @@
|
|
|
6
6
|
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
7
|
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import tensorflow as tf
|
|
13
|
-
from tensorflow import
|
|
14
|
-
from tensorflow.keras import Model, Sequential, layers
|
|
13
|
+
from tensorflow.keras import Model, Sequential, layers, losses
|
|
15
14
|
|
|
16
15
|
from doctr.file_utils import CLASS_NAME
|
|
17
16
|
from doctr.models.classification import resnet18, resnet34, resnet50
|
|
18
|
-
from doctr.models.utils import
|
|
17
|
+
from doctr.models.utils import (
|
|
18
|
+
IntermediateLayerGetter,
|
|
19
|
+
_bf16_to_float32,
|
|
20
|
+
_build_model,
|
|
21
|
+
conv_sequence,
|
|
22
|
+
load_pretrained_params,
|
|
23
|
+
)
|
|
19
24
|
from doctr.utils.repr import NestedObject
|
|
20
25
|
|
|
21
26
|
from .base import LinkNetPostProcessor, _LinkNet
|
|
22
27
|
|
|
23
28
|
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
|
|
24
29
|
|
|
25
|
-
default_cfgs:
|
|
30
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
26
31
|
"linknet_resnet18": {
|
|
27
32
|
"mean": (0.798, 0.785, 0.772),
|
|
28
33
|
"std": (0.264, 0.2749, 0.287),
|
|
29
34
|
"input_shape": (1024, 1024, 3),
|
|
30
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
35
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0",
|
|
31
36
|
},
|
|
32
37
|
"linknet_resnet34": {
|
|
33
38
|
"mean": (0.798, 0.785, 0.772),
|
|
34
39
|
"std": (0.264, 0.2749, 0.287),
|
|
35
40
|
"input_shape": (1024, 1024, 3),
|
|
36
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
41
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0",
|
|
37
42
|
},
|
|
38
43
|
"linknet_resnet50": {
|
|
39
44
|
"mean": (0.798, 0.785, 0.772),
|
|
40
45
|
"std": (0.264, 0.2749, 0.287),
|
|
41
46
|
"input_shape": (1024, 1024, 3),
|
|
42
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
47
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0",
|
|
43
48
|
},
|
|
44
49
|
}
|
|
45
50
|
|
|
@@ -68,7 +73,7 @@ class LinkNetFPN(Model, NestedObject):
|
|
|
68
73
|
def __init__(
|
|
69
74
|
self,
|
|
70
75
|
out_chans: int,
|
|
71
|
-
in_shapes:
|
|
76
|
+
in_shapes: list[tuple[int, ...]],
|
|
72
77
|
) -> None:
|
|
73
78
|
super().__init__()
|
|
74
79
|
self.out_chans = out_chans
|
|
@@ -80,22 +85,21 @@ class LinkNetFPN(Model, NestedObject):
|
|
|
80
85
|
for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
|
|
81
86
|
]
|
|
82
87
|
|
|
83
|
-
def call(self, x:
|
|
88
|
+
def call(self, x: list[tf.Tensor], **kwargs: Any) -> tf.Tensor:
|
|
84
89
|
out = 0
|
|
85
90
|
for decoder, fmap in zip(self.decoders, x[::-1]):
|
|
86
|
-
out = decoder(out + fmap)
|
|
91
|
+
out = decoder(out + fmap, **kwargs)
|
|
87
92
|
return out
|
|
88
93
|
|
|
89
94
|
def extra_repr(self) -> str:
|
|
90
95
|
return f"out_chans={self.out_chans}"
|
|
91
96
|
|
|
92
97
|
|
|
93
|
-
class LinkNet(_LinkNet,
|
|
98
|
+
class LinkNet(_LinkNet, Model):
|
|
94
99
|
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
|
|
95
100
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
96
101
|
|
|
97
102
|
Args:
|
|
98
|
-
----
|
|
99
103
|
feature extractor: the backbone serving as feature extractor
|
|
100
104
|
fpn_channels: number of channels each extracted feature maps is mapped to
|
|
101
105
|
bin_thresh: threshold for binarization of the output feature map
|
|
@@ -106,7 +110,7 @@ class LinkNet(_LinkNet, keras.Model):
|
|
|
106
110
|
class_names: list of class names
|
|
107
111
|
"""
|
|
108
112
|
|
|
109
|
-
_children_names:
|
|
113
|
+
_children_names: list[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]
|
|
110
114
|
|
|
111
115
|
def __init__(
|
|
112
116
|
self,
|
|
@@ -116,8 +120,8 @@ class LinkNet(_LinkNet, keras.Model):
|
|
|
116
120
|
box_thresh: float = 0.1,
|
|
117
121
|
assume_straight_pages: bool = True,
|
|
118
122
|
exportable: bool = False,
|
|
119
|
-
cfg:
|
|
120
|
-
class_names:
|
|
123
|
+
cfg: dict[str, Any] | None = None,
|
|
124
|
+
class_names: list[str] = [CLASS_NAME],
|
|
121
125
|
) -> None:
|
|
122
126
|
super().__init__(cfg=cfg)
|
|
123
127
|
|
|
@@ -162,7 +166,7 @@ class LinkNet(_LinkNet, keras.Model):
|
|
|
162
166
|
def compute_loss(
|
|
163
167
|
self,
|
|
164
168
|
out_map: tf.Tensor,
|
|
165
|
-
target:
|
|
169
|
+
target: list[dict[str, np.ndarray]],
|
|
166
170
|
gamma: float = 2.0,
|
|
167
171
|
alpha: float = 0.5,
|
|
168
172
|
eps: float = 1e-8,
|
|
@@ -171,7 +175,6 @@ class LinkNet(_LinkNet, keras.Model):
|
|
|
171
175
|
<https://github.com/tensorflow/addons/>`_.
|
|
172
176
|
|
|
173
177
|
Args:
|
|
174
|
-
----
|
|
175
178
|
out_map: output feature map of the model of shape N x H x W x 1
|
|
176
179
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
177
180
|
gamma: modulating factor in the focal loss formula
|
|
@@ -179,7 +182,6 @@ class LinkNet(_LinkNet, keras.Model):
|
|
|
179
182
|
eps: epsilon factor in dice loss
|
|
180
183
|
|
|
181
184
|
Returns:
|
|
182
|
-
-------
|
|
183
185
|
A loss tensor
|
|
184
186
|
"""
|
|
185
187
|
seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True)
|
|
@@ -187,7 +189,7 @@ class LinkNet(_LinkNet, keras.Model):
|
|
|
187
189
|
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
|
|
188
190
|
seg_mask = tf.cast(seg_mask, tf.float32)
|
|
189
191
|
|
|
190
|
-
bce_loss =
|
|
192
|
+
bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
|
|
191
193
|
proba_map = tf.sigmoid(out_map)
|
|
192
194
|
|
|
193
195
|
# Focal loss
|
|
@@ -213,16 +215,16 @@ class LinkNet(_LinkNet, keras.Model):
|
|
|
213
215
|
def call(
|
|
214
216
|
self,
|
|
215
217
|
x: tf.Tensor,
|
|
216
|
-
target:
|
|
218
|
+
target: list[dict[str, np.ndarray]] | None = None,
|
|
217
219
|
return_model_output: bool = False,
|
|
218
220
|
return_preds: bool = False,
|
|
219
221
|
**kwargs: Any,
|
|
220
|
-
) ->
|
|
222
|
+
) -> dict[str, Any]:
|
|
221
223
|
feat_maps = self.feat_extractor(x, **kwargs)
|
|
222
224
|
logits = self.fpn(feat_maps, **kwargs)
|
|
223
225
|
logits = self.classifier(logits, **kwargs)
|
|
224
226
|
|
|
225
|
-
out:
|
|
227
|
+
out: dict[str, tf.Tensor] = {}
|
|
226
228
|
if self.exportable:
|
|
227
229
|
out["logits"] = logits
|
|
228
230
|
return out
|
|
@@ -248,9 +250,9 @@ def _linknet(
|
|
|
248
250
|
arch: str,
|
|
249
251
|
pretrained: bool,
|
|
250
252
|
backbone_fn,
|
|
251
|
-
fpn_layers:
|
|
253
|
+
fpn_layers: list[str],
|
|
252
254
|
pretrained_backbone: bool = True,
|
|
253
|
-
input_shape:
|
|
255
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
254
256
|
**kwargs: Any,
|
|
255
257
|
) -> LinkNet:
|
|
256
258
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -275,9 +277,16 @@ def _linknet(
|
|
|
275
277
|
|
|
276
278
|
# Build the model
|
|
277
279
|
model = LinkNet(feat_extractor, cfg=_cfg, **kwargs)
|
|
280
|
+
_build_model(model)
|
|
281
|
+
|
|
278
282
|
# Load pretrained parameters
|
|
279
283
|
if pretrained:
|
|
280
|
-
|
|
284
|
+
# The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
|
|
285
|
+
load_pretrained_params(
|
|
286
|
+
model,
|
|
287
|
+
_cfg["url"],
|
|
288
|
+
skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
|
|
289
|
+
)
|
|
281
290
|
|
|
282
291
|
return model
|
|
283
292
|
|
|
@@ -293,12 +302,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
293
302
|
>>> out = model(input_tensor)
|
|
294
303
|
|
|
295
304
|
Args:
|
|
296
|
-
----
|
|
297
305
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
298
306
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
299
307
|
|
|
300
308
|
Returns:
|
|
301
|
-
-------
|
|
302
309
|
text detection architecture
|
|
303
310
|
"""
|
|
304
311
|
return _linknet(
|
|
@@ -321,12 +328,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
321
328
|
>>> out = model(input_tensor)
|
|
322
329
|
|
|
323
330
|
Args:
|
|
324
|
-
----
|
|
325
331
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
326
332
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
327
333
|
|
|
328
334
|
Returns:
|
|
329
|
-
-------
|
|
330
335
|
text detection architecture
|
|
331
336
|
"""
|
|
332
337
|
return _linknet(
|
|
@@ -349,12 +354,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
349
354
|
>>> out = model(input_tensor)
|
|
350
355
|
|
|
351
356
|
Args:
|
|
352
|
-
----
|
|
353
357
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
354
358
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
355
359
|
|
|
356
360
|
Returns:
|
|
357
|
-
-------
|
|
358
361
|
text detection architecture
|
|
359
362
|
"""
|
|
360
363
|
return _linknet(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
from doctr.file_utils import is_tf_available
|
|
1
|
+
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -20,7 +20,6 @@ class DetectionPredictor(nn.Module):
|
|
|
20
20
|
"""Implements an object able to localize text elements in a document
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
pre_processor: transform inputs for easier batched model inference
|
|
25
24
|
model: core detection architecture
|
|
26
25
|
"""
|
|
@@ -37,10 +36,10 @@ class DetectionPredictor(nn.Module):
|
|
|
37
36
|
@torch.inference_mode()
|
|
38
37
|
def forward(
|
|
39
38
|
self,
|
|
40
|
-
pages:
|
|
39
|
+
pages: list[np.ndarray | torch.Tensor],
|
|
41
40
|
return_maps: bool = False,
|
|
42
41
|
**kwargs: Any,
|
|
43
|
-
) ->
|
|
42
|
+
) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
|
|
44
43
|
# Extract parameters from the preprocessor
|
|
45
44
|
preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
|
|
46
45
|
symmetric_pad = self.pre_processor.resize.symmetric_pad
|
|
@@ -60,11 +59,11 @@ class DetectionPredictor(nn.Module):
|
|
|
60
59
|
]
|
|
61
60
|
# Remove padding from loc predictions
|
|
62
61
|
preds = _remove_padding(
|
|
63
|
-
pages,
|
|
62
|
+
pages,
|
|
64
63
|
[pred for batch in predicted_batches for pred in batch["preds"]],
|
|
65
64
|
preserve_aspect_ratio=preserve_aspect_ratio,
|
|
66
65
|
symmetric_pad=symmetric_pad,
|
|
67
|
-
assume_straight_pages=assume_straight_pages,
|
|
66
|
+
assume_straight_pages=assume_straight_pages, # type: ignore[arg-type]
|
|
68
67
|
)
|
|
69
68
|
|
|
70
69
|
if return_maps:
|
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
10
|
-
from tensorflow import
|
|
10
|
+
from tensorflow.keras import Model
|
|
11
11
|
|
|
12
12
|
from doctr.models.detection._utils import _remove_padding
|
|
13
13
|
from doctr.models.preprocessor import PreProcessor
|
|
@@ -20,27 +20,26 @@ class DetectionPredictor(NestedObject):
|
|
|
20
20
|
"""Implements an object able to localize text elements in a document
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
pre_processor: transform inputs for easier batched model inference
|
|
25
24
|
model: core detection architecture
|
|
26
25
|
"""
|
|
27
26
|
|
|
28
|
-
_children_names:
|
|
27
|
+
_children_names: list[str] = ["pre_processor", "model"]
|
|
29
28
|
|
|
30
29
|
def __init__(
|
|
31
30
|
self,
|
|
32
31
|
pre_processor: PreProcessor,
|
|
33
|
-
model:
|
|
32
|
+
model: Model,
|
|
34
33
|
) -> None:
|
|
35
34
|
self.pre_processor = pre_processor
|
|
36
35
|
self.model = model
|
|
37
36
|
|
|
38
37
|
def __call__(
|
|
39
38
|
self,
|
|
40
|
-
pages:
|
|
39
|
+
pages: list[np.ndarray | tf.Tensor],
|
|
41
40
|
return_maps: bool = False,
|
|
42
41
|
**kwargs: Any,
|
|
43
|
-
) ->
|
|
42
|
+
) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
|
|
44
43
|
# Extract parameters from the preprocessor
|
|
45
44
|
preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
|
|
46
45
|
symmetric_pad = self.pre_processor.resize.symmetric_pad
|
doctr/models/detection/zoo.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
9
9
|
|
|
@@ -14,7 +14,7 @@ from .predictor import DetectionPredictor
|
|
|
14
14
|
|
|
15
15
|
__all__ = ["detection_predictor"]
|
|
16
16
|
|
|
17
|
-
ARCHS:
|
|
17
|
+
ARCHS: list[str]
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
if is_tf_available():
|
|
@@ -56,7 +56,14 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
|
|
|
56
56
|
if isinstance(_model, detection.FAST):
|
|
57
57
|
_model = reparameterize(_model)
|
|
58
58
|
else:
|
|
59
|
-
|
|
59
|
+
allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
|
|
60
|
+
if is_torch_available():
|
|
61
|
+
# Adding the type for torch compiled models to the allowed architectures
|
|
62
|
+
from doctr.models.utils import _CompiledModule
|
|
63
|
+
|
|
64
|
+
allowed_archs.append(_CompiledModule)
|
|
65
|
+
|
|
66
|
+
if not isinstance(arch, tuple(allowed_archs)):
|
|
60
67
|
raise ValueError(f"unknown architecture: {type(arch)}")
|
|
61
68
|
|
|
62
69
|
_model = arch
|
|
@@ -79,6 +86,9 @@ def detection_predictor(
|
|
|
79
86
|
arch: Any = "fast_base",
|
|
80
87
|
pretrained: bool = False,
|
|
81
88
|
assume_straight_pages: bool = True,
|
|
89
|
+
preserve_aspect_ratio: bool = True,
|
|
90
|
+
symmetric_pad: bool = True,
|
|
91
|
+
batch_size: int = 2,
|
|
82
92
|
**kwargs: Any,
|
|
83
93
|
) -> DetectionPredictor:
|
|
84
94
|
"""Text detection architecture.
|
|
@@ -90,14 +100,24 @@ def detection_predictor(
|
|
|
90
100
|
>>> out = model([input_page])
|
|
91
101
|
|
|
92
102
|
Args:
|
|
93
|
-
----
|
|
94
103
|
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
|
|
95
104
|
pretrained: If True, returns a model pre-trained on our text detection dataset
|
|
96
105
|
assume_straight_pages: If True, fit straight boxes to the page
|
|
106
|
+
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
|
|
107
|
+
running the detection model on it
|
|
108
|
+
symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
|
|
109
|
+
batch_size: number of samples the model processes in parallel
|
|
97
110
|
**kwargs: optional keyword arguments passed to the architecture
|
|
98
111
|
|
|
99
112
|
Returns:
|
|
100
|
-
-------
|
|
101
113
|
Detection predictor
|
|
102
114
|
"""
|
|
103
|
-
return _predictor(
|
|
115
|
+
return _predictor(
|
|
116
|
+
arch=arch,
|
|
117
|
+
pretrained=pretrained,
|
|
118
|
+
assume_straight_pages=assume_straight_pages,
|
|
119
|
+
preserve_aspect_ratio=preserve_aspect_ratio,
|
|
120
|
+
symmetric_pad=symmetric_pad,
|
|
121
|
+
batch_size=batch_size,
|
|
122
|
+
**kwargs,
|
|
123
|
+
)
|
doctr/models/factory/hub.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -20,7 +20,6 @@ from huggingface_hub import (
|
|
|
20
20
|
get_token_permission,
|
|
21
21
|
hf_hub_download,
|
|
22
22
|
login,
|
|
23
|
-
snapshot_download,
|
|
24
23
|
)
|
|
25
24
|
|
|
26
25
|
from doctr import models
|
|
@@ -33,7 +32,7 @@ __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config
|
|
|
33
32
|
|
|
34
33
|
|
|
35
34
|
AVAILABLE_ARCHS = {
|
|
36
|
-
"classification": models.classification.zoo.ARCHS,
|
|
35
|
+
"classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
|
|
37
36
|
"detection": models.detection.zoo.ARCHS,
|
|
38
37
|
"recognition": models.recognition.zoo.ARCHS,
|
|
39
38
|
}
|
|
@@ -62,7 +61,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
|
|
|
62
61
|
"""Save model and config to disk for pushing to huggingface hub
|
|
63
62
|
|
|
64
63
|
Args:
|
|
65
|
-
----
|
|
66
64
|
model: TF or PyTorch model to be saved
|
|
67
65
|
save_dir: directory to save model and config
|
|
68
66
|
arch: architecture name
|
|
@@ -74,7 +72,7 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
|
|
|
74
72
|
weights_path = save_directory / "pytorch_model.bin"
|
|
75
73
|
torch.save(model.state_dict(), weights_path)
|
|
76
74
|
elif is_tf_available():
|
|
77
|
-
weights_path = save_directory / "tf_model
|
|
75
|
+
weights_path = save_directory / "tf_model.weights.h5"
|
|
78
76
|
model.save_weights(str(weights_path))
|
|
79
77
|
|
|
80
78
|
config_path = save_directory / "config.json"
|
|
@@ -98,7 +96,6 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
|
|
|
98
96
|
>>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
|
|
99
97
|
|
|
100
98
|
Args:
|
|
101
|
-
----
|
|
102
99
|
model: TF or PyTorch model to be saved
|
|
103
100
|
model_name: name of the model which is also the repository name
|
|
104
101
|
task: task name
|
|
@@ -115,9 +112,9 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
|
|
|
115
112
|
# default readme
|
|
116
113
|
readme = textwrap.dedent(
|
|
117
114
|
f"""
|
|
118
|
-
|
|
115
|
+
|
|
119
116
|
language: en
|
|
120
|
-
|
|
117
|
+
|
|
121
118
|
|
|
122
119
|
<p align="center">
|
|
123
120
|
<img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
|
|
@@ -174,7 +171,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
|
|
|
174
171
|
|
|
175
172
|
local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
|
|
176
173
|
repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
|
|
177
|
-
repo = Repository(local_dir=local_cache_dir, clone_from=repo_url
|
|
174
|
+
repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)
|
|
178
175
|
|
|
179
176
|
with repo.commit(commit_message):
|
|
180
177
|
_save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
|
|
@@ -191,12 +188,10 @@ def from_hub(repo_id: str, **kwargs: Any):
|
|
|
191
188
|
>>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")
|
|
192
189
|
|
|
193
190
|
Args:
|
|
194
|
-
----
|
|
195
191
|
repo_id: HuggingFace model hub repo
|
|
196
192
|
kwargs: kwargs of `hf_hub_download`
|
|
197
193
|
|
|
198
194
|
Returns:
|
|
199
|
-
-------
|
|
200
195
|
Model loaded with the checkpoint
|
|
201
196
|
"""
|
|
202
197
|
# Get the config
|
|
@@ -225,7 +220,7 @@ def from_hub(repo_id: str, **kwargs: Any):
|
|
|
225
220
|
state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
|
|
226
221
|
model.load_state_dict(state_dict)
|
|
227
222
|
else: # tf
|
|
228
|
-
|
|
229
|
-
model.load_weights(
|
|
223
|
+
weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
|
|
224
|
+
model.load_weights(weights)
|
|
230
225
|
|
|
231
226
|
return model
|