python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
--- a/doctr/models/detection/differentiable_binarization/pytorch.py
+++ b/doctr/models/detection/differentiable_binarization/pytorch.py
@@ -1,9 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 import torch
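Most of the small-count churn across both backends is the same typing modernization: `typing.Dict`/`List`/`Optional`/`Tuple` give way to PEP 585 builtin generics and PEP 604 `|` unions (and `Callable` moves to `collections.abc`), which presumably reflects the package's new Python floor, since these spellings need Python >= 3.10 at runtime. A before/after sketch of the pattern:

```python
# before: typing-module generics, as in 0.10.0
from typing import Dict, List, Optional

def build(names: Optional[List[str]] = None) -> Dict[str, int]:
    return {n: i for i, n in enumerate(names or [])}

# after: builtin generics and union syntax, as in 0.12.0 - no typing imports needed
def build(names: list[str] | None = None) -> dict[str, int]:
    return {n: i for i, n in enumerate(names or [])}
```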
@@ -22,7 +23,7 @@ from .base import DBPostProcessor, _DBNet
 __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "db_resnet50": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
@@ -47,7 +48,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 class FeaturePyramidNetwork(nn.Module):
     def __init__(
         self,
-        in_channels:
+        in_channels: list[int],
         out_channels: int,
         deform_conv: bool = False,
     ) -> None:
@@ -76,12 +77,12 @@ class FeaturePyramidNetwork(nn.Module):
             for idx, chans in enumerate(in_channels)
         ])
 
-    def forward(self, x:
+    def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
         if len(x) != len(self.out_branches):
             raise AssertionError
         # Conv1x1 to get the same number of channels
-        _x:
-        out:
+        _x: list[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
+        out: list[torch.Tensor] = [_x[-1]]
         for t in _x[:-1][::-1]:
             out.append(self.upsample(out[-1]) + t)
 
@@ -96,7 +97,6 @@ class DBNet(_DBNet, nn.Module):
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         head_chans: the number of channels in the head
         deform_conv: whether to use deformable convolution
@@ -117,8 +117,8 @@ class DBNet(_DBNet, nn.Module):
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg:
-        class_names:
+        cfg: dict[str, Any] | None = None,
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -179,13 +179,22 @@ class DBNet(_DBNet, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
-        target:
+        target: list[np.ndarray] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) ->
+    ) -> dict[str, torch.Tensor]:
         # Extract feature maps at different stages
         feats = self.feat_extractor(x)
         feats = [feats[str(idx)] for idx in range(len(feats))]
@@ -193,7 +202,7 @@ class DBNet(_DBNet, nn.Module):
         feat_concat = self.fpn(feats)
         logits = self.prob_head(feat_concat)
 
-        out:
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -205,11 +214,16 @@ class DBNet(_DBNet, nn.Module):
             out["out_map"] = prob_map
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
+                return [
+                    dict(zip(self.class_names, preds))
+                    for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                ]
+
             # Post-process boxes (keep only text predictions)
-            out["preds"] =
-                dict(zip(self.class_names, preds))
-                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
-            ]
+            out["preds"] = _postprocess(prob_map)
 
         if target is not None:
             thresh_map = self.thresh_head(feat_concat)
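The post-processing is now wrapped in a local function decorated with `torch.compiler.disable`, so a `torch.compile`d model skips the NumPy round-trip instead of tripping over an untraceable graph break. A standalone sketch of the pattern (not doctr's exact code):

```python
import torch

@torch.compiler.disable  # excluded from torch.compile tracing
def to_scores(prob_map: torch.Tensor) -> list[float]:
    # CPU/NumPy work like this cannot be captured by the compiler
    return prob_map.detach().cpu().numpy().max(axis=(1, 2, 3)).tolist()

@torch.compile
def model_fn(x: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(x * 2.0)

scores = to_scores(model_fn(torch.rand(2, 1, 8, 8)))
```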
@@ -222,7 +236,7 @@ class DBNet(_DBNet, nn.Module):
         self,
         out_map: torch.Tensor,
         thresh_map: torch.Tensor,
-        target:
+        target: list[np.ndarray],
         gamma: float = 2.0,
         alpha: float = 0.5,
         eps: float = 1e-8,
@@ -231,7 +245,6 @@ class DBNet(_DBNet, nn.Module):
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
-        ----
             out_map: output feature map of the model of shape (N, C, H, W)
             thresh_map: threshold map of shape (N, C, H, W)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
@@ -240,7 +253,6 @@ class DBNet(_DBNet, nn.Module):
             eps: epsilon factor in dice loss
 
         Returns:
-        -------
             A loss tensor
         """
         if gamma < 0:
@@ -273,7 +285,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
@@ -290,10 +302,10 @@ def _dbnet(
     arch: str,
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
-    fpn_layers:
-    backbone_submodule:
+    fpn_layers: list[str],
+    backbone_submodule: str | None = None,
     pretrained_backbone: bool = True,
-    ignore_keys:
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -325,7 +337,7 @@ def _dbnet(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -341,12 +353,10 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
@@ -376,12 +386,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
@@ -411,12 +419,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _dbnet(
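Both backends gain the same `from_pretrained` method, and the builder functions now call it instead of invoking `load_pretrained_params` directly. A minimal usage sketch (the checkpoint path is hypothetical; `ignore_keys` is forwarded as-is to `doctr.models.utils.load_pretrained_params`):

```python
import torch
from doctr.models import db_resnet50

model = db_resnet50(pretrained=False)
# Load weights from a local checkpoint or a URL (path below is hypothetical)
model.from_pretrained("path/to/db_resnet50_checkpoint.pt", ignore_keys=None)

model.eval()
with torch.no_grad():
    out = model(torch.rand((1, 3, 1024, 1024)), return_preds=True)
print(out["preds"])
```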
--- a/doctr/models/detection/differentiable_binarization/tensorflow.py
+++ b/doctr/models/detection/differentiable_binarization/tensorflow.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -29,7 +29,7 @@ from .base import DBPostProcessor, _DBNet
 __all__ = ["DBNet", "db_resnet50", "db_mobilenet_v3_large"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "db_resnet50": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
@@ -50,7 +50,6 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
     <https://arxiv.org/pdf/1612.03144.pdf>`_.
 
     Args:
-    ----
         channels: number of channel to output
     """
 
@@ -72,12 +71,10 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         """Module which performs a 3x3 convolution followed by up-sampling
 
         Args:
-        ----
             channels: number of output channels
             dilation_factor (int): dilation factor to scale the convolution output before concatenation
 
         Returns:
-        -------
             a keras.layers.Layer object, wrapping these operations in a sequential module
 
         """
@@ -95,7 +92,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
 
     def call(
         self,
-        x:
+        x: list[tf.Tensor],
         **kwargs: Any,
     ) -> tf.Tensor:
         # Channel mapping
@@ -114,7 +111,6 @@ class DBNet(_DBNet, Model, NestedObject):
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
         bin_thresh: threshold for binarization
@@ -125,7 +121,7 @@ class DBNet(_DBNet, Model, NestedObject):
         class_names: list of class names
     """
 
-    _children_names:
+    _children_names: list[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
 
     def __init__(
         self,
@@ -135,8 +131,8 @@ class DBNet(_DBNet, Model, NestedObject):
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg:
-        class_names:
+        cfg: dict[str, Any] | None = None,
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -171,11 +167,20 @@ class DBNet(_DBNet, Model, NestedObject):
             assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
         )
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         out_map: tf.Tensor,
         thresh_map: tf.Tensor,
-        target:
+        target: list[dict[str, np.ndarray]],
         gamma: float = 2.0,
         alpha: float = 0.5,
         eps: float = 1e-8,
@@ -184,7 +189,6 @@ class DBNet(_DBNet, Model, NestedObject):
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
-        ----
             out_map: output feature map of the model of shape (N, H, W, C)
             thresh_map: threshold map of shape (N, H, W, C)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
@@ -193,7 +197,6 @@ class DBNet(_DBNet, Model, NestedObject):
             eps: epsilon factor in dice loss
 
         Returns:
-        -------
             A loss tensor
         """
         if gamma < 0:
@@ -246,16 +249,16 @@ class DBNet(_DBNet, Model, NestedObject):
     def call(
         self,
         x: tf.Tensor,
-        target:
+        target: list[dict[str, np.ndarray]] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         feat_maps = self.feat_extractor(x, **kwargs)
         feat_concat = self.fpn(feat_maps, **kwargs)
         logits = self.probability_head(feat_concat, **kwargs)
 
-        out:
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -282,9 +285,9 @@ def _db_resnet(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    fpn_layers:
+    fpn_layers: list[str],
     pretrained_backbone: bool = True,
-    input_shape:
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -315,8 +318,7 @@ def _db_resnet(
     # Load pretrained parameters
     if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
-
-            model,
+        model.from_pretrained(
             _cfg["url"],
             skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
         )
@@ -328,9 +330,9 @@ def _db_mobilenet(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    fpn_layers:
+    fpn_layers: list[str],
     pretrained_backbone: bool = True,
-    input_shape:
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> DBNet:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -359,8 +361,7 @@ def _db_mobilenet(
     # Load pretrained parameters
     if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
-
-            model,
+        model.from_pretrained(
             _cfg["url"],
             skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
         )
@@ -379,12 +380,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _db_resnet(
@@ -407,12 +406,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _db_mobilenet(
--- a/doctr/models/detection/fast/__init__.py
+++ b/doctr/models/detection/fast/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
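This `__init__.py` spells out doctr's standard backend dispatch: the module re-exports whichever implementation matches the installed framework, so downstream imports stay backend-agnostic. A sketch of what that means for user code (assuming `fast_base` keeps its public name in the detection zoo):

```python
# The same import resolves to the PyTorch or TensorFlow implementation,
# depending on which framework doctr detected at import time.
from doctr.models.detection import fast_base

model = fast_base(pretrained=False)
```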
--- a/doctr/models/detection/fast/base.py
+++ b/doctr/models/detection/fast/base.py
@@ -1,11 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -23,7 +22,6 @@ class FASTPostProcessor(DetectionPostProcessor):
     """Implements a post processor for FAST model.
 
     Args:
-    ----
         bin_thresh: threshold used to binzarized p_map at inference time
         box_thresh: minimal objectness score to consider a box
         assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -45,11 +43,9 @@ class FASTPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
-        ----
             points: The first parameter.
 
         Returns:
-        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -60,9 +56,8 @@ class FASTPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-
-
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
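For rotated pages, the polygon's area and perimeter are now computed with OpenCV (`cv2.contourArea` / `cv2.arcLength`) instead of going through a shapely-style `poly` object, per the removed `poly.length` line. The offset distance then follows DB's unclip formula, distance = area * unclip_ratio / length. A self-contained sketch of that computation:

```python
import cv2
import numpy as np

points = np.array([[0, 0], [40, 0], [40, 10], [0, 10]], dtype=np.float32)
unclip_ratio = 1.0

area = cv2.contourArea(points)               # 400.0 for this rectangle
length = cv2.arcLength(points, closed=True)  # 100.0 (perimeter)
distance = area * unclip_ratio / length      # 4.0: how far to expand the polygon
```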
@@ -94,24 +89,22 @@
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
-        ----
             pred: Pred map from differentiable linknet output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
-        -------
             np tensor boxes for the bitmap, each box is a 6-element list
             containing x, y, w, h, alpha, score for the box
         """
         height, width = bitmap.shape[:2]
-        boxes:
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -158,20 +151,18 @@ class _FAST(BaseModel):
 
     def build_target(
         self,
-        target:
-        output_shape:
+        target: list[dict[str, np.ndarray]],
+        output_shape: tuple[int, int, int],
         channels_last: bool = True,
-    ) ->
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
-        ----
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
             channels_last: whether channels are last or not
 
         Returns:
-        -------
             the new formatted target, mask and shrunken text kernel
         """
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):