python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
import torch
|
|
@@ -20,7 +21,7 @@ from .base import LinkNetPostProcessor, _LinkNet
|
|
|
20
21
|
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
|
|
21
22
|
|
|
22
23
|
|
|
23
|
-
default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
24
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
24
25
|
"linknet_resnet18": {
|
|
25
26
|
"input_shape": (3, 1024, 1024),
|
|
26
27
|
"mean": (0.798, 0.785, 0.772),
|
|
@@ -43,7 +44,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
43
44
|
|
|
44
45
|
|
|
45
46
|
class LinkNetFPN(nn.Module):
|
|
46
|
-
def __init__(self, layer_shapes:
|
|
47
|
+
def __init__(self, layer_shapes: list[tuple[int, int, int]]) -> None:
|
|
47
48
|
super().__init__()
|
|
48
49
|
strides = [
|
|
49
50
|
1 if (in_shape[-1] == out_shape[-1]) else 2
|
|
@@ -74,7 +75,7 @@ class LinkNetFPN(nn.Module):
|
|
|
74
75
|
nn.ReLU(inplace=True),
|
|
75
76
|
)
|
|
76
77
|
|
|
77
|
-
def forward(self, feats:
|
|
78
|
+
def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
|
|
78
79
|
out = feats[-1]
|
|
79
80
|
for decoder, fmap in zip(self.decoders[::-1], feats[:-1][::-1]):
|
|
80
81
|
out = decoder(out) + fmap
|
|
@@ -89,7 +90,6 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
89
90
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
90
91
|
|
|
91
92
|
Args:
|
|
92
|
-
----
|
|
93
93
|
feature extractor: the backbone serving as feature extractor
|
|
94
94
|
bin_thresh: threshold for binarization of the output feature map
|
|
95
95
|
box_thresh: minimal objectness score to consider a box
|
|
@@ -108,8 +108,8 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
108
108
|
head_chans: int = 32,
|
|
109
109
|
assume_straight_pages: bool = True,
|
|
110
110
|
exportable: bool = False,
|
|
111
|
-
cfg:
|
|
112
|
-
class_names:
|
|
111
|
+
cfg: dict[str, Any] | None = None,
|
|
112
|
+
class_names: list[str] = [CLASS_NAME],
|
|
113
113
|
) -> None:
|
|
114
114
|
super().__init__()
|
|
115
115
|
self.class_names = class_names
|
|
@@ -160,19 +160,28 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
160
160
|
m.weight.data.fill_(1.0)
|
|
161
161
|
m.bias.data.zero_()
|
|
162
162
|
|
|
163
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
164
|
+
"""Load pretrained parameters onto the model
|
|
165
|
+
|
|
166
|
+
Args:
|
|
167
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
168
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
169
|
+
"""
|
|
170
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
171
|
+
|
|
163
172
|
def forward(
|
|
164
173
|
self,
|
|
165
174
|
x: torch.Tensor,
|
|
166
|
-
target:
|
|
175
|
+
target: list[np.ndarray] | None = None,
|
|
167
176
|
return_model_output: bool = False,
|
|
168
177
|
return_preds: bool = False,
|
|
169
178
|
**kwargs: Any,
|
|
170
|
-
) ->
|
|
179
|
+
) -> dict[str, Any]:
|
|
171
180
|
feats = self.feat_extractor(x)
|
|
172
181
|
logits = self.fpn([feats[str(idx)] for idx in range(len(feats))])
|
|
173
182
|
logits = self.classifier(logits)
|
|
174
183
|
|
|
175
|
-
out:
|
|
184
|
+
out: dict[str, Any] = {}
|
|
176
185
|
if self.exportable:
|
|
177
186
|
out["logits"] = logits
|
|
178
187
|
return out
|
|
@@ -183,11 +192,16 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
183
192
|
out["out_map"] = prob_map
|
|
184
193
|
|
|
185
194
|
if target is None or return_preds:
|
|
186
|
-
#
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
195
|
+
# Disable for torch.compile compatibility
|
|
196
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
197
|
+
def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
|
|
198
|
+
return [
|
|
199
|
+
dict(zip(self.class_names, preds))
|
|
200
|
+
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
|
|
201
|
+
]
|
|
202
|
+
|
|
203
|
+
# Post-process boxes (keep only text predictions)
|
|
204
|
+
out["preds"] = _postprocess(prob_map)
|
|
191
205
|
|
|
192
206
|
if target is not None:
|
|
193
207
|
loss = self.compute_loss(logits, target)
|
|
@@ -198,7 +212,7 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
198
212
|
def compute_loss(
|
|
199
213
|
self,
|
|
200
214
|
out_map: torch.Tensor,
|
|
201
|
-
target:
|
|
215
|
+
target: list[np.ndarray],
|
|
202
216
|
gamma: float = 2.0,
|
|
203
217
|
alpha: float = 0.5,
|
|
204
218
|
eps: float = 1e-8,
|
|
@@ -207,7 +221,6 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
207
221
|
<https://github.com/tensorflow/addons/>`_.
|
|
208
222
|
|
|
209
223
|
Args:
|
|
210
|
-
----
|
|
211
224
|
out_map: output feature map of the model of shape (N, num_classes, H, W)
|
|
212
225
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
213
226
|
gamma: modulating factor in the focal loss formula
|
|
@@ -215,7 +228,6 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
215
228
|
eps: epsilon factor in dice loss
|
|
216
229
|
|
|
217
230
|
Returns:
|
|
218
|
-
-------
|
|
219
231
|
A loss tensor
|
|
220
232
|
"""
|
|
221
233
|
_target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
|
|
@@ -252,9 +264,9 @@ def _linknet(
|
|
|
252
264
|
arch: str,
|
|
253
265
|
pretrained: bool,
|
|
254
266
|
backbone_fn: Callable[[bool], nn.Module],
|
|
255
|
-
fpn_layers:
|
|
267
|
+
fpn_layers: list[str],
|
|
256
268
|
pretrained_backbone: bool = True,
|
|
257
|
-
ignore_keys:
|
|
269
|
+
ignore_keys: list[str] | None = None,
|
|
258
270
|
**kwargs: Any,
|
|
259
271
|
) -> LinkNet:
|
|
260
272
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -279,7 +291,7 @@ def _linknet(
|
|
|
279
291
|
_ignore_keys = (
|
|
280
292
|
ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
|
|
281
293
|
)
|
|
282
|
-
|
|
294
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
283
295
|
|
|
284
296
|
return model
|
|
285
297
|
|
|
@@ -295,12 +307,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
295
307
|
>>> out = model(input_tensor)
|
|
296
308
|
|
|
297
309
|
Args:
|
|
298
|
-
----
|
|
299
310
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
300
311
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
301
312
|
|
|
302
313
|
Returns:
|
|
303
|
-
-------
|
|
304
314
|
text detection architecture
|
|
305
315
|
"""
|
|
306
316
|
return _linknet(
|
|
@@ -327,12 +337,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
327
337
|
>>> out = model(input_tensor)
|
|
328
338
|
|
|
329
339
|
Args:
|
|
330
|
-
----
|
|
331
340
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
332
341
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
333
342
|
|
|
334
343
|
Returns:
|
|
335
|
-
-------
|
|
336
344
|
text detection architecture
|
|
337
345
|
"""
|
|
338
346
|
return _linknet(
|
|
@@ -359,12 +367,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
359
367
|
>>> out = model(input_tensor)
|
|
360
368
|
|
|
361
369
|
Args:
|
|
362
|
-
----
|
|
363
370
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
364
371
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
365
372
|
|
|
366
373
|
Returns:
|
|
367
|
-
-------
|
|
368
374
|
text detection architecture
|
|
369
375
|
"""
|
|
370
376
|
return _linknet(
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
7
|
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import tensorflow as tf
|
|
@@ -27,7 +27,7 @@ from .base import LinkNetPostProcessor, _LinkNet
|
|
|
27
27
|
|
|
28
28
|
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
|
|
29
29
|
|
|
30
|
-
default_cfgs:
|
|
30
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
31
31
|
"linknet_resnet18": {
|
|
32
32
|
"mean": (0.798, 0.785, 0.772),
|
|
33
33
|
"std": (0.264, 0.2749, 0.287),
|
|
@@ -73,7 +73,7 @@ class LinkNetFPN(Model, NestedObject):
|
|
|
73
73
|
def __init__(
|
|
74
74
|
self,
|
|
75
75
|
out_chans: int,
|
|
76
|
-
in_shapes:
|
|
76
|
+
in_shapes: list[tuple[int, ...]],
|
|
77
77
|
) -> None:
|
|
78
78
|
super().__init__()
|
|
79
79
|
self.out_chans = out_chans
|
|
@@ -85,7 +85,7 @@ class LinkNetFPN(Model, NestedObject):
|
|
|
85
85
|
for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
|
|
86
86
|
]
|
|
87
87
|
|
|
88
|
-
def call(self, x:
|
|
88
|
+
def call(self, x: list[tf.Tensor], **kwargs: Any) -> tf.Tensor:
|
|
89
89
|
out = 0
|
|
90
90
|
for decoder, fmap in zip(self.decoders, x[::-1]):
|
|
91
91
|
out = decoder(out + fmap, **kwargs)
|
|
@@ -100,7 +100,6 @@ class LinkNet(_LinkNet, Model):
|
|
|
100
100
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
101
101
|
|
|
102
102
|
Args:
|
|
103
|
-
----
|
|
104
103
|
feature extractor: the backbone serving as feature extractor
|
|
105
104
|
fpn_channels: number of channels each extracted feature maps is mapped to
|
|
106
105
|
bin_thresh: threshold for binarization of the output feature map
|
|
@@ -111,7 +110,7 @@ class LinkNet(_LinkNet, Model):
|
|
|
111
110
|
class_names: list of class names
|
|
112
111
|
"""
|
|
113
112
|
|
|
114
|
-
_children_names:
|
|
113
|
+
_children_names: list[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]
|
|
115
114
|
|
|
116
115
|
def __init__(
|
|
117
116
|
self,
|
|
@@ -121,8 +120,8 @@ class LinkNet(_LinkNet, Model):
|
|
|
121
120
|
box_thresh: float = 0.1,
|
|
122
121
|
assume_straight_pages: bool = True,
|
|
123
122
|
exportable: bool = False,
|
|
124
|
-
cfg:
|
|
125
|
-
class_names:
|
|
123
|
+
cfg: dict[str, Any] | None = None,
|
|
124
|
+
class_names: list[str] = [CLASS_NAME],
|
|
126
125
|
) -> None:
|
|
127
126
|
super().__init__(cfg=cfg)
|
|
128
127
|
|
|
@@ -164,10 +163,19 @@ class LinkNet(_LinkNet, Model):
|
|
|
164
163
|
assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
|
|
165
164
|
)
|
|
166
165
|
|
|
166
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
167
|
+
"""Load pretrained parameters onto the model
|
|
168
|
+
|
|
169
|
+
Args:
|
|
170
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
171
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
172
|
+
"""
|
|
173
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
174
|
+
|
|
167
175
|
def compute_loss(
|
|
168
176
|
self,
|
|
169
177
|
out_map: tf.Tensor,
|
|
170
|
-
target:
|
|
178
|
+
target: list[dict[str, np.ndarray]],
|
|
171
179
|
gamma: float = 2.0,
|
|
172
180
|
alpha: float = 0.5,
|
|
173
181
|
eps: float = 1e-8,
|
|
@@ -176,7 +184,6 @@ class LinkNet(_LinkNet, Model):
|
|
|
176
184
|
<https://github.com/tensorflow/addons/>`_.
|
|
177
185
|
|
|
178
186
|
Args:
|
|
179
|
-
----
|
|
180
187
|
out_map: output feature map of the model of shape N x H x W x 1
|
|
181
188
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
182
189
|
gamma: modulating factor in the focal loss formula
|
|
@@ -184,7 +191,6 @@ class LinkNet(_LinkNet, Model):
|
|
|
184
191
|
eps: epsilon factor in dice loss
|
|
185
192
|
|
|
186
193
|
Returns:
|
|
187
|
-
-------
|
|
188
194
|
A loss tensor
|
|
189
195
|
"""
|
|
190
196
|
seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True)
|
|
@@ -218,16 +224,16 @@ class LinkNet(_LinkNet, Model):
|
|
|
218
224
|
def call(
|
|
219
225
|
self,
|
|
220
226
|
x: tf.Tensor,
|
|
221
|
-
target:
|
|
227
|
+
target: list[dict[str, np.ndarray]] | None = None,
|
|
222
228
|
return_model_output: bool = False,
|
|
223
229
|
return_preds: bool = False,
|
|
224
230
|
**kwargs: Any,
|
|
225
|
-
) ->
|
|
231
|
+
) -> dict[str, Any]:
|
|
226
232
|
feat_maps = self.feat_extractor(x, **kwargs)
|
|
227
233
|
logits = self.fpn(feat_maps, **kwargs)
|
|
228
234
|
logits = self.classifier(logits, **kwargs)
|
|
229
235
|
|
|
230
|
-
out:
|
|
236
|
+
out: dict[str, tf.Tensor] = {}
|
|
231
237
|
if self.exportable:
|
|
232
238
|
out["logits"] = logits
|
|
233
239
|
return out
|
|
@@ -253,9 +259,9 @@ def _linknet(
|
|
|
253
259
|
arch: str,
|
|
254
260
|
pretrained: bool,
|
|
255
261
|
backbone_fn,
|
|
256
|
-
fpn_layers:
|
|
262
|
+
fpn_layers: list[str],
|
|
257
263
|
pretrained_backbone: bool = True,
|
|
258
|
-
input_shape:
|
|
264
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
259
265
|
**kwargs: Any,
|
|
260
266
|
) -> LinkNet:
|
|
261
267
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -285,8 +291,7 @@ def _linknet(
|
|
|
285
291
|
# Load pretrained parameters
|
|
286
292
|
if pretrained:
|
|
287
293
|
# The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
|
|
288
|
-
|
|
289
|
-
model,
|
|
294
|
+
model.from_pretrained(
|
|
290
295
|
_cfg["url"],
|
|
291
296
|
skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
|
|
292
297
|
)
|
|
@@ -305,12 +310,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
305
310
|
>>> out = model(input_tensor)
|
|
306
311
|
|
|
307
312
|
Args:
|
|
308
|
-
----
|
|
309
313
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
310
314
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
311
315
|
|
|
312
316
|
Returns:
|
|
313
|
-
-------
|
|
314
317
|
text detection architecture
|
|
315
318
|
"""
|
|
316
319
|
return _linknet(
|
|
@@ -333,12 +336,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
333
336
|
>>> out = model(input_tensor)
|
|
334
337
|
|
|
335
338
|
Args:
|
|
336
|
-
----
|
|
337
339
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
338
340
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
339
341
|
|
|
340
342
|
Returns:
|
|
341
|
-
-------
|
|
342
343
|
text detection architecture
|
|
343
344
|
"""
|
|
344
345
|
return _linknet(
|
|
@@ -361,12 +362,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
361
362
|
>>> out = model(input_tensor)
|
|
362
363
|
|
|
363
364
|
Args:
|
|
364
|
-
----
|
|
365
365
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
366
366
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
367
367
|
|
|
368
368
|
Returns:
|
|
369
|
-
-------
|
|
370
369
|
text detection architecture
|
|
371
370
|
"""
|
|
372
371
|
return _linknet(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
from doctr.file_utils import is_tf_available
|
|
1
|
+
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -20,7 +20,6 @@ class DetectionPredictor(nn.Module):
|
|
|
20
20
|
"""Implements an object able to localize text elements in a document
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
pre_processor: transform inputs for easier batched model inference
|
|
25
24
|
model: core detection architecture
|
|
26
25
|
"""
|
|
@@ -37,10 +36,10 @@ class DetectionPredictor(nn.Module):
|
|
|
37
36
|
@torch.inference_mode()
|
|
38
37
|
def forward(
|
|
39
38
|
self,
|
|
40
|
-
pages:
|
|
39
|
+
pages: list[np.ndarray | torch.Tensor],
|
|
41
40
|
return_maps: bool = False,
|
|
42
41
|
**kwargs: Any,
|
|
43
|
-
) ->
|
|
42
|
+
) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
|
|
44
43
|
# Extract parameters from the preprocessor
|
|
45
44
|
preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
|
|
46
45
|
symmetric_pad = self.pre_processor.resize.symmetric_pad
|
|
@@ -60,11 +59,11 @@ class DetectionPredictor(nn.Module):
|
|
|
60
59
|
]
|
|
61
60
|
# Remove padding from loc predictions
|
|
62
61
|
preds = _remove_padding(
|
|
63
|
-
pages,
|
|
62
|
+
pages,
|
|
64
63
|
[pred for batch in predicted_batches for pred in batch["preds"]],
|
|
65
64
|
preserve_aspect_ratio=preserve_aspect_ratio,
|
|
66
65
|
symmetric_pad=symmetric_pad,
|
|
67
|
-
assume_straight_pages=assume_straight_pages,
|
|
66
|
+
assume_straight_pages=assume_straight_pages, # type: ignore[arg-type]
|
|
68
67
|
)
|
|
69
68
|
|
|
70
69
|
if return_maps:
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
@@ -20,12 +20,11 @@ class DetectionPredictor(NestedObject):
|
|
|
20
20
|
"""Implements an object able to localize text elements in a document
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
pre_processor: transform inputs for easier batched model inference
|
|
25
24
|
model: core detection architecture
|
|
26
25
|
"""
|
|
27
26
|
|
|
28
|
-
_children_names:
|
|
27
|
+
_children_names: list[str] = ["pre_processor", "model"]
|
|
29
28
|
|
|
30
29
|
def __init__(
|
|
31
30
|
self,
|
|
@@ -37,10 +36,10 @@ class DetectionPredictor(NestedObject):
|
|
|
37
36
|
|
|
38
37
|
def __call__(
|
|
39
38
|
self,
|
|
40
|
-
pages:
|
|
39
|
+
pages: list[np.ndarray | tf.Tensor],
|
|
41
40
|
return_maps: bool = False,
|
|
42
41
|
**kwargs: Any,
|
|
43
|
-
) ->
|
|
42
|
+
) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
|
|
44
43
|
# Extract parameters from the preprocessor
|
|
45
44
|
preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
|
|
46
45
|
symmetric_pad = self.pre_processor.resize.symmetric_pad
|
doctr/models/detection/zoo.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
9
9
|
|
|
@@ -14,7 +14,7 @@ from .predictor import DetectionPredictor
|
|
|
14
14
|
|
|
15
15
|
__all__ = ["detection_predictor"]
|
|
16
16
|
|
|
17
|
-
ARCHS:
|
|
17
|
+
ARCHS: list[str]
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
if is_tf_available():
|
|
@@ -56,7 +56,14 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
|
|
|
56
56
|
if isinstance(_model, detection.FAST):
|
|
57
57
|
_model = reparameterize(_model)
|
|
58
58
|
else:
|
|
59
|
-
|
|
59
|
+
allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
|
|
60
|
+
if is_torch_available():
|
|
61
|
+
# Adding the type for torch compiled models to the allowed architectures
|
|
62
|
+
from doctr.models.utils import _CompiledModule
|
|
63
|
+
|
|
64
|
+
allowed_archs.append(_CompiledModule)
|
|
65
|
+
|
|
66
|
+
if not isinstance(arch, tuple(allowed_archs)):
|
|
60
67
|
raise ValueError(f"unknown architecture: {type(arch)}")
|
|
61
68
|
|
|
62
69
|
_model = arch
|
|
@@ -79,6 +86,9 @@ def detection_predictor(
|
|
|
79
86
|
arch: Any = "fast_base",
|
|
80
87
|
pretrained: bool = False,
|
|
81
88
|
assume_straight_pages: bool = True,
|
|
89
|
+
preserve_aspect_ratio: bool = True,
|
|
90
|
+
symmetric_pad: bool = True,
|
|
91
|
+
batch_size: int = 2,
|
|
82
92
|
**kwargs: Any,
|
|
83
93
|
) -> DetectionPredictor:
|
|
84
94
|
"""Text detection architecture.
|
|
@@ -90,14 +100,24 @@ def detection_predictor(
|
|
|
90
100
|
>>> out = model([input_page])
|
|
91
101
|
|
|
92
102
|
Args:
|
|
93
|
-
----
|
|
94
103
|
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
|
|
95
104
|
pretrained: If True, returns a model pre-trained on our text detection dataset
|
|
96
105
|
assume_straight_pages: If True, fit straight boxes to the page
|
|
106
|
+
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
|
|
107
|
+
running the detection model on it
|
|
108
|
+
symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
|
|
109
|
+
batch_size: number of samples the model processes in parallel
|
|
97
110
|
**kwargs: optional keyword arguments passed to the architecture
|
|
98
111
|
|
|
99
112
|
Returns:
|
|
100
|
-
-------
|
|
101
113
|
Detection predictor
|
|
102
114
|
"""
|
|
103
|
-
return _predictor(
|
|
115
|
+
return _predictor(
|
|
116
|
+
arch=arch,
|
|
117
|
+
pretrained=pretrained,
|
|
118
|
+
assume_straight_pages=assume_straight_pages,
|
|
119
|
+
preserve_aspect_ratio=preserve_aspect_ratio,
|
|
120
|
+
symmetric_pad=symmetric_pad,
|
|
121
|
+
batch_size=batch_size,
|
|
122
|
+
**kwargs,
|
|
123
|
+
)
|
doctr/models/factory/hub.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -61,7 +61,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
|
|
|
61
61
|
"""Save model and config to disk for pushing to huggingface hub
|
|
62
62
|
|
|
63
63
|
Args:
|
|
64
|
-
----
|
|
65
64
|
model: TF or PyTorch model to be saved
|
|
66
65
|
save_dir: directory to save model and config
|
|
67
66
|
arch: architecture name
|
|
@@ -97,7 +96,6 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
|
|
|
97
96
|
>>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
|
|
98
97
|
|
|
99
98
|
Args:
|
|
100
|
-
----
|
|
101
99
|
model: TF or PyTorch model to be saved
|
|
102
100
|
model_name: name of the model which is also the repository name
|
|
103
101
|
task: task name
|
|
@@ -114,9 +112,9 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
|
|
|
114
112
|
# default readme
|
|
115
113
|
readme = textwrap.dedent(
|
|
116
114
|
f"""
|
|
117
|
-
|
|
115
|
+
|
|
118
116
|
language: en
|
|
119
|
-
|
|
117
|
+
|
|
120
118
|
|
|
121
119
|
<p align="center">
|
|
122
120
|
<img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
|
|
@@ -190,12 +188,10 @@ def from_hub(repo_id: str, **kwargs: Any):
|
|
|
190
188
|
>>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")
|
|
191
189
|
|
|
192
190
|
Args:
|
|
193
|
-
----
|
|
194
191
|
repo_id: HuggingFace model hub repo
|
|
195
192
|
kwargs: kwargs of `hf_hub_download` or `snapshot_download`
|
|
196
193
|
|
|
197
194
|
Returns:
|
|
198
|
-
-------
|
|
199
195
|
Model loaded with the checkpoint
|
|
200
196
|
"""
|
|
201
197
|
# Get the config
|
|
@@ -221,10 +217,10 @@ def from_hub(repo_id: str, **kwargs: Any):
|
|
|
221
217
|
|
|
222
218
|
# Load checkpoint
|
|
223
219
|
if is_torch_available():
|
|
224
|
-
|
|
225
|
-
model.load_state_dict(state_dict)
|
|
220
|
+
weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
|
|
226
221
|
else: # tf
|
|
227
222
|
weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
|
|
228
|
-
|
|
223
|
+
|
|
224
|
+
model.from_pretrained(weights)
|
|
229
225
|
|
|
230
226
|
return model
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
from doctr.file_utils import is_tf_available
|
|
1
|
+
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
from doctr.models.builder import KIEDocumentBuilder
|
|
9
9
|
|
|
@@ -17,7 +17,6 @@ class _KIEPredictor(_OCRPredictor):
|
|
|
17
17
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
18
18
|
|
|
19
19
|
Args:
|
|
20
|
-
----
|
|
21
20
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
22
21
|
without rotated textual elements.
|
|
23
22
|
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
|
|
@@ -30,8 +29,8 @@ class _KIEPredictor(_OCRPredictor):
|
|
|
30
29
|
kwargs: keyword args of `DocumentBuilder`
|
|
31
30
|
"""
|
|
32
31
|
|
|
33
|
-
crop_orientation_predictor:
|
|
34
|
-
page_orientation_predictor:
|
|
32
|
+
crop_orientation_predictor: OrientationPredictor | None
|
|
33
|
+
page_orientation_predictor: OrientationPredictor | None
|
|
35
34
|
|
|
36
35
|
def __init__(
|
|
37
36
|
self,
|