python-doctr 0.10.0-py3-none-any.whl → 0.12.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
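The release also bumps `doctr/version.py` (+1 -1 above), so the upgrade is checkable at runtime. A quick sanity check, assuming `doctr` re-exports `__version__` from that module as in previous releases:

```python
# Hedged sketch: confirm which side of this diff is installed.
import doctr

print(doctr.__version__)  # "0.10.0" before the upgrade, "0.12.0" after
```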
doctr/models/detection/fast/pytorch.py (+37 -35)

```diff
@@ -1,9 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 import torch
@@ -21,7 +22,7 @@ from .base import _FAST, FASTPostProcessor
 __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "fast_tiny": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
@@ -47,7 +48,6 @@ class FastNeck(nn.Module):
     """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layers.
 
     Args:
-    ----
         in_channels: number of input channels
         out_channels: number of output channels
     """
@@ -77,7 +77,6 @@ class FastHead(nn.Sequential):
     """Head of the FAST architecture
 
     Args:
-    ----
         in_channels: number of input channels
         num_classes: number of output classes
         out_channels: number of output channels
@@ -91,7 +90,7 @@ class FastHead(nn.Sequential):
         out_channels: int = 128,
         dropout: float = 0.1,
     ) -> None:
-        _layers:
+        _layers: list[nn.Module] = [
            FASTConvLayer(in_channels, out_channels, kernel_size=3),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, num_classes, kernel_size=1, bias=False),
@@ -104,7 +103,6 @@ class FAST(_FAST, nn.Module):
     <https://arxiv.org/pdf/2111.02394.pdf>`_.
 
     Args:
-    ----
         feat extractor: the backbone serving as feature extractor
         bin_thresh: threshold for binarization
         box_thresh: minimal objectness score to consider a box
@@ -125,8 +123,8 @@ class FAST(_FAST, nn.Module):
         pooling_size: int = 4,  # different from paper performs better on close text-rich images
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg:
-        class_names:
+        cfg: dict[str, Any] = {},
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -172,13 +170,22 @@ class FAST(_FAST, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
-        target:
+        target: list[np.ndarray] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) ->
+    ) -> dict[str, torch.Tensor]:
         # Extract feature maps at different stages
         feats = self.feat_extractor(x)
         feats = [feats[str(idx)] for idx in range(len(feats))]
@@ -186,7 +193,7 @@ class FAST(_FAST, nn.Module):
         feat_concat = self.neck(feats)
         logits = F.interpolate(self.prob_head(feat_concat), size=x.shape[-2:], mode="bilinear")
 
-        out:
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -198,11 +205,16 @@ class FAST(_FAST, nn.Module):
             out["out_map"] = prob_map
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
+                return [
+                    dict(zip(self.class_names, preds))
+                    for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                ]
+
             # Post-process boxes (keep only text predictions)
-            out["preds"] =
-                dict(zip(self.class_names, preds))
-                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
-            ]
+            out["preds"] = _postprocess(prob_map)
 
         if target is not None:
             loss = self.compute_loss(logits, target)
```
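The last hunk above factors the numpy-based post-processing into a local `_postprocess` closure decorated with `torch.compiler.disable`, so `forward` stays traceable by `torch.compile` while the `.cpu().numpy()` round-trip runs eagerly. A minimal, self-contained sketch of the same pattern (assumes PyTorch >= 2.1, where `torch.compiler.disable` is available; function names are illustrative):

```python
import torch

@torch.compiler.disable  # excluded from graph capture, always runs eagerly
def to_list(t: torch.Tensor) -> list:
    # graph-breaking work (device transfer + numpy conversion) is safe here
    return t.detach().cpu().numpy().tolist()

def forward(x: torch.Tensor) -> list:
    return to_list(torch.sigmoid(x))  # the sigmoid is still compiled

compiled = torch.compile(forward)
print(compiled(torch.randn(2, 3)))
```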
```diff
@@ -213,19 +225,17 @@ class FAST(_FAST, nn.Module):
     def compute_loss(
         self,
         out_map: torch.Tensor,
-        target:
+        target: list[np.ndarray],
         eps: float = 1e-6,
     ) -> torch.Tensor:
         """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
 
         Args:
-        ----
             out_map: output feature map of the model of shape (N, num_classes, H, W)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
             eps: epsilon factor in dice loss
 
         Returns:
-        -------
             A loss tensor
         """
         targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
```
```diff
@@ -279,15 +289,13 @@ class FAST(_FAST, nn.Module):
         return text_loss + kernel_loss
 
 
-def reparameterize(model: Union[FAST, nn.Module]) -> FAST:
+def reparameterize(model: FAST | nn.Module) -> FAST:
     """Fuse batchnorm and conv layers and reparameterize the model
 
-
-    ----
+    Args:
         model: the FAST model to reparameterize
 
     Returns:
-    -------
         the reparameterized model
     """
     last_conv = None
@@ -303,12 +311,12 @@ def reparameterize(model: Union[FAST, nn.Module]) -> FAST:
             if last_conv is None:
                 continue
             conv_w = last_conv.weight
-            conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean)
+            conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean)  # type: ignore[arg-type]
 
-            factor = child.weight / torch.sqrt(child.running_var + child.eps)
+            factor = child.weight / torch.sqrt(child.running_var + child.eps)  # type: ignore
             last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
             last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
-            model._modules[last_conv_name] = last_conv
+            model._modules[last_conv_name] = last_conv  # type: ignore[index]
             model._modules[name] = nn.Identity()
             last_conv = None
         elif isinstance(child, nn.Conv2d):
```
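The `reparameterize` hunks above fold each `nn.BatchNorm2d` into the preceding convolution: with BN scale gamma, shift beta and running statistics mu, var, the fused parameters are w' = w * gamma / sqrt(var + eps) and b' = (b - mu) * gamma / sqrt(var + eps) + beta, which is exactly the `factor` arithmetic in the diff. A self-contained sketch of that fusion (`fuse_conv_bn` is an illustrative helper, not the library function):

```python
import torch
from torch import nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a single conv equivalent to bn(conv(x)) in eval mode."""
    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,  # type: ignore[arg-type]
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=True,
    )
    fused.weight = nn.Parameter(conv.weight * factor.reshape(-1, 1, 1, 1))
    fused.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
    return fused

# Sanity check: outputs match once the BN statistics are frozen
conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)
bn.eval()
x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-5)
```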
```diff
@@ -324,9 +332,9 @@ def _fast(
     arch: str,
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
-    feat_layers:
+    feat_layers: list[str],
     pretrained_backbone: bool = True,
-    ignore_keys:
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> FAST:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -350,7 +358,7 @@ def _fast(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -366,12 +374,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _fast(
@@ -395,12 +401,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _fast(
@@ -424,12 +428,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _fast(
```
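Both backends now expose the same `from_pretrained` method (the PyTorch variant above, the TensorFlow one below), replacing the direct `load_pretrained_params(model, ...)` calls in the `_fast` factories. A hedged usage sketch, with a placeholder checkpoint path:

```python
from doctr.models import fast_tiny

# Build the architecture without downloading weights, then load a checkpoint
# by hand, e.g. one produced by a fine-tuning run (path is a placeholder).
model = fast_tiny(pretrained=False)
model.from_pretrained("path/to/fast_checkpoint.pt")
```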
doctr/models/detection/fast/tensorflow.py (+24 -28)

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -23,7 +23,7 @@ from .base import _FAST, FASTPostProcessor
 __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "fast_tiny": {
         "input_shape": (1024, 1024, 3),
         "mean": (0.798, 0.785, 0.772),
@@ -49,7 +49,6 @@ class FastNeck(layers.Layer, NestedObject):
     """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer.
 
     Args:
-    ----
         in_channels: number of input channels
         out_channels: number of output channels
     """
@@ -77,7 +76,6 @@ class FastHead(Sequential):
     """Head of the FAST architecture
 
     Args:
-    ----
         in_channels: number of input channels
         num_classes: number of output classes
         out_channels: number of output channels
@@ -104,7 +102,6 @@ class FAST(_FAST, Model, NestedObject):
     <https://arxiv.org/pdf/2111.02394.pdf>`_.
 
     Args:
-    ----
         feature extractor: the backbone serving as feature extractor
         bin_thresh: threshold for binarization
         box_thresh: minimal objectness score to consider a box
@@ -116,7 +113,7 @@ class FAST(_FAST, Model, NestedObject):
         class_names: list of class names
     """
 
-    _children_names:
+    _children_names: list[str] = ["feat_extractor", "neck", "head", "postprocessor"]
 
     def __init__(
         self,
@@ -127,8 +124,8 @@ class FAST(_FAST, Model, NestedObject):
         pooling_size: int = 4,  # different from paper performs better on close text-rich images
         assume_straight_pages: bool = True,
         exportable: bool = False,
-        cfg:
-        class_names:
+        cfg: dict[str, Any] = {},
+        class_names: list[str] = [CLASS_NAME],
     ) -> None:
         super().__init__()
         self.class_names = class_names
@@ -156,22 +153,29 @@ class FAST(_FAST, Model, NestedObject):
         # Pooling layer as erosion reversal as described in the paper
         self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same")
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         out_map: tf.Tensor,
-        target:
+        target: list[dict[str, np.ndarray]],
         eps: float = 1e-6,
     ) -> tf.Tensor:
         """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
 
         Args:
-        ----
             out_map: output feature map of the model of shape (N, num_classes, H, W)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
             eps: epsilon factor in dice loss
 
         Returns:
-        -------
             A loss tensor
         """
         targets = self.build_target(target, out_map.shape[1:], True)
@@ -222,18 +226,18 @@ class FAST(_FAST, Model, NestedObject):
     def call(
         self,
         x: tf.Tensor,
-        target:
+        target: list[dict[str, np.ndarray]] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         feat_maps = self.feat_extractor(x, **kwargs)
         # Pass through the Neck & Head & Upsample
         feat_concat = self.neck(feat_maps, **kwargs)
         logits: tf.Tensor = self.head(feat_concat, **kwargs)
         logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)
 
-        out:
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -255,15 +259,14 @@ class FAST(_FAST, Model, NestedObject):
             return out
 
 
-def reparameterize(model:
+def reparameterize(model: FAST | layers.Layer) -> FAST:
     """Fuse batchnorm and conv layers and reparameterize the model
 
     args:
-
+
         model: the FAST model to reparameterize
 
     Returns:
-    -------
         the reparameterized model
     """
     last_conv = None
@@ -306,9 +309,9 @@ def _fast(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    feat_layers:
+    feat_layers: list[str],
     pretrained_backbone: bool = True,
-    input_shape:
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> FAST:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -338,8 +341,7 @@ def _fast(
     # Load pretrained parameters
     if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
-
-            model,
+        model.from_pretrained(
             _cfg["url"],
             skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
         )
@@ -358,12 +360,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _fast(
@@ -386,12 +386,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _fast(
@@ -414,12 +412,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
         **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
-    -------
         text detection architecture
     """
     return _fast(
```
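As in the PyTorch file, the factory's pretrained branch keeps its mismatch handling: when custom `class_names` are passed, the detection head no longer matches the single-class checkpoint, and `skip_mismatch` (or `ignore_keys` on the PyTorch side) leaves those layers unloaded. A hedged fine-tuning sketch, assuming `class_names` is forwarded through `**kwargs` as the hunks above suggest:

```python
from doctr.models import fast_tiny

# Backbone weights load normally; head layers whose shapes no longer match
# the pretrained single-class checkpoint are skipped for fine-tuning.
model = fast_tiny(pretrained=True, class_names=["words", "stamps"])
```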
doctr/models/detection/linknet/__init__.py (+4 -4)

```diff
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
```
doctr/models/detection/linknet/base.py (+8 -18)

```diff
@@ -1,11 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -23,7 +22,6 @@ class LinkNetPostProcessor(DetectionPostProcessor):
     """Implements a post processor for LinkNet model.
 
     Args:
-    ----
         bin_thresh: threshold used to binzarized p_map at inference time
         box_thresh: minimal objectness score to consider a box
         assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -45,11 +43,9 @@ class LinkNetPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
-        ----
             points: The first parameter.
 
         Returns:
-        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -60,9 +56,8 @@ class LinkNetPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-
-
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
@@ -94,24 +89,22 @@ class LinkNetPostProcessor(DetectionPostProcessor):
         """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
-        ----
             pred: Pred map from differentiable linknet output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
-        -------
             np tensor boxes for the bitmap, each box is a 6-element list
                 containing x, y, w, h, alpha, score for the box
         """
         height, width = bitmap.shape[:2]
-        boxes:
+        boxes: list[np.ndarray | list[float]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -152,7 +145,6 @@ class _LinkNet(BaseModel):
     <https://arxiv.org/pdf/1707.03718.pdf>`_.
 
     Args:
-    ----
         out_chan: number of channels for the output
     """
 
@@ -162,20 +154,18 @@ class _LinkNet(BaseModel):
 
     def build_target(
         self,
-        target:
-        output_shape:
+        target: list[dict[str, np.ndarray]],
+        output_shape: tuple[int, int, int],
         channels_last: bool = True,
-    ) ->
+    ) -> tuple[np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
-        ----
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
             channels_last: whether channels are last or not
 
         Returns:
-        -------
             the new formatted target and the mask
         """
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
```
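The polygon-expansion hunk earlier in this file swaps the shapely-style `poly.length` for OpenCV's `cv2.contourArea` and `cv2.arcLength` when measuring rotated polygons; the expansion distance `area * unclip_ratio / length` then drives a pyclipper offset. A self-contained sketch of that unclip step (`unclip` is an illustrative helper name):

```python
import cv2
import numpy as np
import pyclipper

def unclip(points: np.ndarray, unclip_ratio: float = 1.5) -> np.ndarray:
    """Expand an (N, 2) polygon outward, mirroring the diffed logic."""
    area = cv2.contourArea(points)
    length = cv2.arcLength(points, closed=True)
    distance = area * unclip_ratio / length  # distance to expand the polygon by
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(points.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    return np.asarray(offset.Execute(distance)[0])

square = np.array([[0, 0], [10, 0], [10, 10], [0, 10]], dtype=np.int32)
print(unclip(square))  # a slightly larger, rounded square
```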
|