python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +8 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +7 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +4 -5
- doctr/datasets/ic13.py +4 -5
- doctr/datasets/iiit5k.py +6 -5
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +6 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +6 -5
- doctr/datasets/svt.py +4 -5
- doctr/datasets/synthtext.py +4 -5
- doctr/datasets/utils.py +34 -29
- doctr/datasets/vocabs.py +17 -7
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +2 -7
- doctr/io/elements.py +59 -79
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +30 -48
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +8 -11
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +5 -17
- doctr/models/classification/mobilenet/tensorflow.py +8 -21
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +6 -8
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +20 -31
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +8 -15
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +9 -12
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +6 -12
- doctr/models/classification/zoo.py +19 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +14 -26
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +14 -23
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +3 -7
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +18 -19
- doctr/models/kie_predictor/tensorflow.py +13 -14
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +12 -13
- doctr/models/predictor/tensorflow.py +8 -9
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +11 -23
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +12 -22
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +16 -22
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +12 -21
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +12 -20
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +14 -17
- doctr/models/utils/tensorflow.py +17 -16
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +16 -47
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
@@ -21,13 +21,12 @@ class RecognitionPredictor(NestedObject):
|
|
|
21
21
|
"""Implements an object able to identify character sequences in images
|
|
22
22
|
|
|
23
23
|
Args:
|
|
24
|
-
----
|
|
25
24
|
pre_processor: transform inputs for easier batched model inference
|
|
26
25
|
model: core detection architecture
|
|
27
26
|
split_wide_crops: wether to use crop splitting for high aspect ratio crops
|
|
28
27
|
"""
|
|
29
28
|
|
|
30
|
-
_children_names:
|
|
29
|
+
_children_names: list[str] = ["pre_processor", "model"]
|
|
31
30
|
|
|
32
31
|
def __init__(
|
|
33
32
|
self,
|
|
@@ -45,9 +44,9 @@ class RecognitionPredictor(NestedObject):
|
|
|
45
44
|
|
|
46
45
|
def __call__(
|
|
47
46
|
self,
|
|
48
|
-
crops:
|
|
47
|
+
crops: list[np.ndarray | tf.Tensor],
|
|
49
48
|
**kwargs: Any,
|
|
50
|
-
) ->
|
|
49
|
+
) -> list[tuple[str, float]]:
|
|
51
50
|
if len(crops) == 0:
|
|
52
51
|
return []
|
|
53
52
|
# Dimension check
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
elif
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
8
9
|
|
|
9
10
|
import torch
|
|
10
11
|
from torch import nn
|
|
@@ -19,7 +20,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
|
|
|
19
20
|
|
|
20
21
|
__all__ = ["SAR", "sar_resnet31"]
|
|
21
22
|
|
|
22
|
-
default_cfgs:
|
|
23
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
23
24
|
"sar_resnet31": {
|
|
24
25
|
"mean": (0.694, 0.695, 0.693),
|
|
25
26
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -80,7 +81,6 @@ class SARDecoder(nn.Module):
|
|
|
80
81
|
"""Implements decoder module of the SAR model
|
|
81
82
|
|
|
82
83
|
Args:
|
|
83
|
-
----
|
|
84
84
|
rnn_units: number of hidden units in recurrent cells
|
|
85
85
|
max_length: maximum length of a sequence
|
|
86
86
|
vocab_size: number of classes in the model alphabet
|
|
@@ -114,12 +114,12 @@ class SARDecoder(nn.Module):
|
|
|
114
114
|
self,
|
|
115
115
|
features: torch.Tensor, # (N, C, H, W)
|
|
116
116
|
holistic: torch.Tensor, # (N, C)
|
|
117
|
-
gt:
|
|
117
|
+
gt: torch.Tensor | None = None, # (N, L)
|
|
118
118
|
) -> torch.Tensor:
|
|
119
119
|
if gt is not None:
|
|
120
120
|
gt_embedding = self.embed_tgt(gt)
|
|
121
121
|
|
|
122
|
-
logits_list:
|
|
122
|
+
logits_list: list[torch.Tensor] = []
|
|
123
123
|
|
|
124
124
|
for t in range(self.max_length + 1): # 32
|
|
125
125
|
if t == 0:
|
|
@@ -166,7 +166,6 @@ class SAR(nn.Module, RecognitionModel):
|
|
|
166
166
|
Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
|
|
167
167
|
|
|
168
168
|
Args:
|
|
169
|
-
----
|
|
170
169
|
feature_extractor: the backbone serving as feature extractor
|
|
171
170
|
vocab: vocabulary used for encoding
|
|
172
171
|
rnn_units: number of hidden units in both encoder and decoder LSTM
|
|
@@ -187,9 +186,9 @@ class SAR(nn.Module, RecognitionModel):
|
|
|
187
186
|
attention_units: int = 512,
|
|
188
187
|
max_length: int = 30,
|
|
189
188
|
dropout_prob: float = 0.0,
|
|
190
|
-
input_shape:
|
|
189
|
+
input_shape: tuple[int, int, int] = (3, 32, 128),
|
|
191
190
|
exportable: bool = False,
|
|
192
|
-
cfg:
|
|
191
|
+
cfg: dict[str, Any] | None = None,
|
|
193
192
|
) -> None:
|
|
194
193
|
super().__init__()
|
|
195
194
|
self.vocab = vocab
|
|
@@ -232,10 +231,10 @@ class SAR(nn.Module, RecognitionModel):
|
|
|
232
231
|
def forward(
|
|
233
232
|
self,
|
|
234
233
|
x: torch.Tensor,
|
|
235
|
-
target:
|
|
234
|
+
target: list[str] | None = None,
|
|
236
235
|
return_model_output: bool = False,
|
|
237
236
|
return_preds: bool = False,
|
|
238
|
-
) ->
|
|
237
|
+
) -> dict[str, Any]:
|
|
239
238
|
features = self.feat_extractor(x)["features"]
|
|
240
239
|
# NOTE: use max instead of functional max_pool2d which leads to ONNX incompatibility (kernel_size)
|
|
241
240
|
# Vertical max pooling (N, C, H, W) --> (N, C, W)
|
|
@@ -254,7 +253,7 @@ class SAR(nn.Module, RecognitionModel):
|
|
|
254
253
|
|
|
255
254
|
decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))
|
|
256
255
|
|
|
257
|
-
out:
|
|
256
|
+
out: dict[str, Any] = {}
|
|
258
257
|
if self.exportable:
|
|
259
258
|
out["logits"] = decoded_features
|
|
260
259
|
return out
|
|
@@ -263,8 +262,13 @@ class SAR(nn.Module, RecognitionModel):
|
|
|
263
262
|
out["out_map"] = decoded_features
|
|
264
263
|
|
|
265
264
|
if target is None or return_preds:
|
|
265
|
+
# Disable for torch.compile compatibility
|
|
266
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
267
|
+
def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
|
|
268
|
+
return self.postprocessor(decoded_features)
|
|
269
|
+
|
|
266
270
|
# Post-process boxes
|
|
267
|
-
out["preds"] =
|
|
271
|
+
out["preds"] = _postprocess(decoded_features)
|
|
268
272
|
|
|
269
273
|
if target is not None:
|
|
270
274
|
out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
|
|
@@ -281,19 +285,17 @@ class SAR(nn.Module, RecognitionModel):
|
|
|
281
285
|
Sequences are masked after the EOS character.
|
|
282
286
|
|
|
283
287
|
Args:
|
|
284
|
-
----
|
|
285
288
|
model_output: predicted logits of the model
|
|
286
289
|
gt: the encoded tensor with gt labels
|
|
287
290
|
seq_len: lengths of each gt word inside the batch
|
|
288
291
|
|
|
289
292
|
Returns:
|
|
290
|
-
-------
|
|
291
293
|
The loss of the model on the batch
|
|
292
294
|
"""
|
|
293
295
|
# Input length : number of timesteps
|
|
294
296
|
input_len = model_output.shape[1]
|
|
295
297
|
# Add one for additional <eos> token
|
|
296
|
-
seq_len = seq_len + 1
|
|
298
|
+
seq_len = seq_len + 1 # type: ignore[assignment]
|
|
297
299
|
# Compute loss
|
|
298
300
|
# (N, L, vocab_size + 1)
|
|
299
301
|
cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
|
|
@@ -308,14 +310,13 @@ class SARPostProcessor(RecognitionPostProcessor):
|
|
|
308
310
|
"""Post processor for SAR architectures
|
|
309
311
|
|
|
310
312
|
Args:
|
|
311
|
-
----
|
|
312
313
|
vocab: string containing the ordered sequence of supported characters
|
|
313
314
|
"""
|
|
314
315
|
|
|
315
316
|
def __call__(
|
|
316
317
|
self,
|
|
317
318
|
logits: torch.Tensor,
|
|
318
|
-
) ->
|
|
319
|
+
) -> list[tuple[str, float]]:
|
|
319
320
|
# compute pred with argmax for attention models
|
|
320
321
|
out_idxs = logits.argmax(-1)
|
|
321
322
|
# N x L
|
|
@@ -338,7 +339,7 @@ def _sar(
|
|
|
338
339
|
backbone_fn: Callable[[bool], nn.Module],
|
|
339
340
|
layer: str,
|
|
340
341
|
pretrained_backbone: bool = True,
|
|
341
|
-
ignore_keys:
|
|
342
|
+
ignore_keys: list[str] | None = None,
|
|
342
343
|
**kwargs: Any,
|
|
343
344
|
) -> SAR:
|
|
344
345
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -379,12 +380,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
|
|
|
379
380
|
>>> out = model(input_tensor)
|
|
380
381
|
|
|
381
382
|
Args:
|
|
382
|
-
----
|
|
383
383
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
384
384
|
**kwargs: keyword arguments of the SAR architecture
|
|
385
385
|
|
|
386
386
|
Returns:
|
|
387
|
-
-------
|
|
388
387
|
text recognition architecture
|
|
389
388
|
"""
|
|
390
389
|
return _sar(
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import tensorflow as tf
|
|
10
10
|
from tensorflow.keras import Model, Sequential, layers
|
|
@@ -18,7 +18,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
|
|
|
18
18
|
|
|
19
19
|
__all__ = ["SAR", "sar_resnet31"]
|
|
20
20
|
|
|
21
|
-
default_cfgs:
|
|
21
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
22
22
|
"sar_resnet31": {
|
|
23
23
|
"mean": (0.694, 0.695, 0.693),
|
|
24
24
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -33,7 +33,6 @@ class SAREncoder(layers.Layer, NestedObject):
|
|
|
33
33
|
"""Implements encoder module of the SAR model
|
|
34
34
|
|
|
35
35
|
Args:
|
|
36
|
-
----
|
|
37
36
|
rnn_units: number of hidden rnn units
|
|
38
37
|
dropout_prob: dropout probability
|
|
39
38
|
"""
|
|
@@ -58,7 +57,6 @@ class AttentionModule(layers.Layer, NestedObject):
|
|
|
58
57
|
"""Implements attention module of the SAR model
|
|
59
58
|
|
|
60
59
|
Args:
|
|
61
|
-
----
|
|
62
60
|
attention_units: number of hidden attention units
|
|
63
61
|
|
|
64
62
|
"""
|
|
@@ -120,7 +118,6 @@ class SARDecoder(layers.Layer, NestedObject):
|
|
|
120
118
|
"""Implements decoder module of the SAR model
|
|
121
119
|
|
|
122
120
|
Args:
|
|
123
|
-
----
|
|
124
121
|
rnn_units: number of hidden units in recurrent cells
|
|
125
122
|
max_length: maximum length of a sequence
|
|
126
123
|
vocab_size: number of classes in the model alphabet
|
|
@@ -159,13 +156,13 @@ class SARDecoder(layers.Layer, NestedObject):
|
|
|
159
156
|
self,
|
|
160
157
|
features: tf.Tensor,
|
|
161
158
|
holistic: tf.Tensor,
|
|
162
|
-
gt:
|
|
159
|
+
gt: tf.Tensor | None = None,
|
|
163
160
|
**kwargs: Any,
|
|
164
161
|
) -> tf.Tensor:
|
|
165
162
|
if gt is not None:
|
|
166
163
|
gt_embedding = self.embed_tgt(gt, **kwargs)
|
|
167
164
|
|
|
168
|
-
logits_list:
|
|
165
|
+
logits_list: list[tf.Tensor] = []
|
|
169
166
|
|
|
170
167
|
for t in range(self.max_length + 1): # 32
|
|
171
168
|
if t == 0:
|
|
@@ -210,7 +207,6 @@ class SAR(Model, RecognitionModel):
|
|
|
210
207
|
Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
|
|
211
208
|
|
|
212
209
|
Args:
|
|
213
|
-
----
|
|
214
210
|
feature_extractor: the backbone serving as feature extractor
|
|
215
211
|
vocab: vocabulary used for encoding
|
|
216
212
|
rnn_units: number of hidden units in both encoder and decoder LSTM
|
|
@@ -223,7 +219,7 @@ class SAR(Model, RecognitionModel):
|
|
|
223
219
|
cfg: dictionary containing information about the model
|
|
224
220
|
"""
|
|
225
221
|
|
|
226
|
-
_children_names:
|
|
222
|
+
_children_names: list[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
|
|
227
223
|
|
|
228
224
|
def __init__(
|
|
229
225
|
self,
|
|
@@ -236,7 +232,7 @@ class SAR(Model, RecognitionModel):
|
|
|
236
232
|
num_decoder_cells: int = 2,
|
|
237
233
|
dropout_prob: float = 0.0,
|
|
238
234
|
exportable: bool = False,
|
|
239
|
-
cfg:
|
|
235
|
+
cfg: dict[str, Any] | None = None,
|
|
240
236
|
) -> None:
|
|
241
237
|
super().__init__()
|
|
242
238
|
self.vocab = vocab
|
|
@@ -269,13 +265,11 @@ class SAR(Model, RecognitionModel):
|
|
|
269
265
|
Sequences are masked after the EOS character.
|
|
270
266
|
|
|
271
267
|
Args:
|
|
272
|
-
----
|
|
273
268
|
gt: the encoded tensor with gt labels
|
|
274
269
|
model_output: predicted logits of the model
|
|
275
270
|
seq_len: lengths of each gt word inside the batch
|
|
276
271
|
|
|
277
272
|
Returns:
|
|
278
|
-
-------
|
|
279
273
|
The loss of the model on the batch
|
|
280
274
|
"""
|
|
281
275
|
# Input length : number of timesteps
|
|
@@ -296,11 +290,11 @@ class SAR(Model, RecognitionModel):
|
|
|
296
290
|
def call(
|
|
297
291
|
self,
|
|
298
292
|
x: tf.Tensor,
|
|
299
|
-
target:
|
|
293
|
+
target: list[str] | None = None,
|
|
300
294
|
return_model_output: bool = False,
|
|
301
295
|
return_preds: bool = False,
|
|
302
296
|
**kwargs: Any,
|
|
303
|
-
) ->
|
|
297
|
+
) -> dict[str, Any]:
|
|
304
298
|
features = self.feat_extractor(x, **kwargs)
|
|
305
299
|
# vertical max pooling --> (N, C, W)
|
|
306
300
|
pooled_features = tf.reduce_max(features, axis=1)
|
|
@@ -318,7 +312,7 @@ class SAR(Model, RecognitionModel):
|
|
|
318
312
|
self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
|
|
319
313
|
)
|
|
320
314
|
|
|
321
|
-
out:
|
|
315
|
+
out: dict[str, tf.Tensor] = {}
|
|
322
316
|
if self.exportable:
|
|
323
317
|
out["logits"] = decoded_features
|
|
324
318
|
return out
|
|
@@ -340,14 +334,13 @@ class SARPostProcessor(RecognitionPostProcessor):
|
|
|
340
334
|
"""Post processor for SAR architectures
|
|
341
335
|
|
|
342
336
|
Args:
|
|
343
|
-
----
|
|
344
337
|
vocab: string containing the ordered sequence of supported characters
|
|
345
338
|
"""
|
|
346
339
|
|
|
347
340
|
def __call__(
|
|
348
341
|
self,
|
|
349
342
|
logits: tf.Tensor,
|
|
350
|
-
) ->
|
|
343
|
+
) -> list[tuple[str, float]]:
|
|
351
344
|
# compute pred with argmax for attention models
|
|
352
345
|
out_idxs = tf.math.argmax(logits, axis=2)
|
|
353
346
|
# N x L
|
|
@@ -371,7 +364,7 @@ def _sar(
|
|
|
371
364
|
pretrained: bool,
|
|
372
365
|
backbone_fn,
|
|
373
366
|
pretrained_backbone: bool = True,
|
|
374
|
-
input_shape:
|
|
367
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
375
368
|
**kwargs: Any,
|
|
376
369
|
) -> SAR:
|
|
377
370
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -414,12 +407,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
|
|
|
414
407
|
>>> out = model(input_tensor)
|
|
415
408
|
|
|
416
409
|
Args:
|
|
417
|
-
----
|
|
418
410
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
419
411
|
**kwargs: keyword arguments of the SAR architecture
|
|
420
412
|
|
|
421
413
|
Returns:
|
|
422
|
-
-------
|
|
423
414
|
text recognition architecture
|
|
424
415
|
"""
|
|
425
416
|
return _sar("sar_resnet31", pretrained, resnet31, **kwargs)
|
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List
|
|
7
6
|
|
|
8
7
|
from rapidfuzz.distance import Levenshtein
|
|
9
8
|
|
|
@@ -14,18 +13,16 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
|
|
|
14
13
|
"""Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
|
|
15
14
|
|
|
16
15
|
Args:
|
|
17
|
-
----
|
|
18
16
|
a: first char seq, suffix should be similar to b's prefix.
|
|
19
17
|
b: second char seq, prefix should be similar to a's suffix.
|
|
20
18
|
dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
|
|
21
19
|
only used when the mother sequence is splitted on a character repetition
|
|
22
20
|
|
|
23
21
|
Returns:
|
|
24
|
-
-------
|
|
25
22
|
A merged character sequence.
|
|
26
23
|
|
|
27
24
|
Example::
|
|
28
|
-
>>> from doctr.
|
|
25
|
+
>>> from doctr.models.recognition.utils import merge_sequences
|
|
29
26
|
>>> merge_sequences('abcd', 'cdefgh', 1.4)
|
|
30
27
|
'abcdefgh'
|
|
31
28
|
>>> merge_sequences('abcdi', 'cdefgh', 1.4)
|
|
@@ -61,26 +58,24 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
|
|
|
61
58
|
return a[:-1] + b[index - 1 :]
|
|
62
59
|
|
|
63
60
|
|
|
64
|
-
def merge_multi_strings(seq_list:
|
|
61
|
+
def merge_multi_strings(seq_list: list[str], dil_factor: float) -> str:
|
|
65
62
|
"""Recursively merges consecutive string sequences with overlapping characters.
|
|
66
63
|
|
|
67
64
|
Args:
|
|
68
|
-
----
|
|
69
65
|
seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
|
|
70
66
|
dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
|
|
71
67
|
only used when the mother sequence is splitted on a character repetition
|
|
72
68
|
|
|
73
69
|
Returns:
|
|
74
|
-
-------
|
|
75
70
|
A merged character sequence
|
|
76
71
|
|
|
77
72
|
Example::
|
|
78
|
-
>>> from doctr.
|
|
73
|
+
>>> from doctr.models.recognition.utils import merge_multi_sequences
|
|
79
74
|
>>> merge_multi_sequences(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
|
|
80
75
|
'abcdefghijkl'
|
|
81
76
|
"""
|
|
82
77
|
|
|
83
|
-
def _recursive_merge(a: str, seq_list:
|
|
78
|
+
def _recursive_merge(a: str, seq_list: list[str], dil_factor: float) -> str:
|
|
84
79
|
# Recursive version of compute_overlap
|
|
85
80
|
if len(seq_list) == 1:
|
|
86
81
|
return merge_strings(a, seq_list[0], dil_factor)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
elif
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List, Tuple
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
|
|
@@ -17,17 +16,15 @@ class _ViTSTR:
|
|
|
17
16
|
|
|
18
17
|
def build_target(
|
|
19
18
|
self,
|
|
20
|
-
gts:
|
|
21
|
-
) ->
|
|
19
|
+
gts: list[str],
|
|
20
|
+
) -> tuple[np.ndarray, list[int]]:
|
|
22
21
|
"""Encode a list of gts sequences into a np array and gives the corresponding*
|
|
23
22
|
sequence lengths.
|
|
24
23
|
|
|
25
24
|
Args:
|
|
26
|
-
----
|
|
27
25
|
gts: list of ground-truth labels
|
|
28
26
|
|
|
29
27
|
Returns:
|
|
30
|
-
-------
|
|
31
28
|
A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
|
|
32
29
|
"""
|
|
33
30
|
encoded = encode_sequences(
|
|
@@ -45,7 +42,6 @@ class _ViTSTRPostProcessor(RecognitionPostProcessor):
|
|
|
45
42
|
"""Abstract class to postprocess the raw output of the model
|
|
46
43
|
|
|
47
44
|
Args:
|
|
48
|
-
----
|
|
49
45
|
vocab: string containing the ordered sequence of supported characters
|
|
50
46
|
"""
|
|
51
47
|
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
8
9
|
|
|
9
10
|
import torch
|
|
10
11
|
from torch import nn
|
|
@@ -19,7 +20,7 @@ from .base import _ViTSTR, _ViTSTRPostProcessor
|
|
|
19
20
|
|
|
20
21
|
__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
|
|
21
22
|
|
|
22
|
-
default_cfgs:
|
|
23
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
23
24
|
"vitstr_small": {
|
|
24
25
|
"mean": (0.694, 0.695, 0.693),
|
|
25
26
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -42,7 +43,6 @@ class ViTSTR(_ViTSTR, nn.Module):
|
|
|
42
43
|
Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
|
|
43
44
|
|
|
44
45
|
Args:
|
|
45
|
-
----
|
|
46
46
|
feature_extractor: the backbone serving as feature extractor
|
|
47
47
|
vocab: vocabulary used for encoding
|
|
48
48
|
embedding_units: number of embedding units
|
|
@@ -59,9 +59,9 @@ class ViTSTR(_ViTSTR, nn.Module):
|
|
|
59
59
|
vocab: str,
|
|
60
60
|
embedding_units: int,
|
|
61
61
|
max_length: int = 32, # different from paper
|
|
62
|
-
input_shape:
|
|
62
|
+
input_shape: tuple[int, int, int] = (3, 32, 128), # different from paper
|
|
63
63
|
exportable: bool = False,
|
|
64
|
-
cfg:
|
|
64
|
+
cfg: dict[str, Any] | None = None,
|
|
65
65
|
) -> None:
|
|
66
66
|
super().__init__()
|
|
67
67
|
self.vocab = vocab
|
|
@@ -77,10 +77,10 @@ class ViTSTR(_ViTSTR, nn.Module):
|
|
|
77
77
|
def forward(
|
|
78
78
|
self,
|
|
79
79
|
x: torch.Tensor,
|
|
80
|
-
target:
|
|
80
|
+
target: list[str] | None = None,
|
|
81
81
|
return_model_output: bool = False,
|
|
82
82
|
return_preds: bool = False,
|
|
83
|
-
) ->
|
|
83
|
+
) -> dict[str, Any]:
|
|
84
84
|
features = self.feat_extractor(x)["features"] # (batch_size, patches_seqlen, d_model)
|
|
85
85
|
|
|
86
86
|
if target is not None:
|
|
@@ -98,7 +98,7 @@ class ViTSTR(_ViTSTR, nn.Module):
|
|
|
98
98
|
logits = self.head(features).view(B, N, len(self.vocab) + 1) # (batch_size, max_length, vocab + 1)
|
|
99
99
|
decoded_features = _bf16_to_float32(logits[:, 1:]) # remove cls_token
|
|
100
100
|
|
|
101
|
-
out:
|
|
101
|
+
out: dict[str, Any] = {}
|
|
102
102
|
if self.exportable:
|
|
103
103
|
out["logits"] = decoded_features
|
|
104
104
|
return out
|
|
@@ -107,8 +107,13 @@ class ViTSTR(_ViTSTR, nn.Module):
|
|
|
107
107
|
out["out_map"] = decoded_features
|
|
108
108
|
|
|
109
109
|
if target is None or return_preds:
|
|
110
|
+
# Disable for torch.compile compatibility
|
|
111
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
112
|
+
def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
|
|
113
|
+
return self.postprocessor(decoded_features)
|
|
114
|
+
|
|
110
115
|
# Post-process boxes
|
|
111
|
-
out["preds"] =
|
|
116
|
+
out["preds"] = _postprocess(decoded_features)
|
|
112
117
|
|
|
113
118
|
if target is not None:
|
|
114
119
|
out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
|
|
@@ -125,19 +130,17 @@ class ViTSTR(_ViTSTR, nn.Module):
|
|
|
125
130
|
Sequences are masked after the EOS character.
|
|
126
131
|
|
|
127
132
|
Args:
|
|
128
|
-
----
|
|
129
133
|
model_output: predicted logits of the model
|
|
130
134
|
gt: the encoded tensor with gt labels
|
|
131
135
|
seq_len: lengths of each gt word inside the batch
|
|
132
136
|
|
|
133
137
|
Returns:
|
|
134
|
-
-------
|
|
135
138
|
The loss of the model on the batch
|
|
136
139
|
"""
|
|
137
140
|
# Input length : number of steps
|
|
138
141
|
input_len = model_output.shape[1]
|
|
139
142
|
# Add one for additional <eos> token (sos disappear in shift!)
|
|
140
|
-
seq_len = seq_len + 1
|
|
143
|
+
seq_len = seq_len + 1 # type: ignore[assignment]
|
|
141
144
|
# Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
|
|
142
145
|
# The "masked" first gt char is <sos>.
|
|
143
146
|
cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
|
|
@@ -153,14 +156,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
|
|
|
153
156
|
"""Post processor for ViTSTR architecture
|
|
154
157
|
|
|
155
158
|
Args:
|
|
156
|
-
----
|
|
157
159
|
vocab: string containing the ordered sequence of supported characters
|
|
158
160
|
"""
|
|
159
161
|
|
|
160
162
|
def __call__(
|
|
161
163
|
self,
|
|
162
164
|
logits: torch.Tensor,
|
|
163
|
-
) ->
|
|
165
|
+
) -> list[tuple[str, float]]:
|
|
164
166
|
# compute pred with argmax for attention models
|
|
165
167
|
out_idxs = logits.argmax(-1)
|
|
166
168
|
preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
|
|
@@ -183,7 +185,7 @@ def _vitstr(
|
|
|
183
185
|
pretrained: bool,
|
|
184
186
|
backbone_fn: Callable[[bool], nn.Module],
|
|
185
187
|
layer: str,
|
|
186
|
-
ignore_keys:
|
|
188
|
+
ignore_keys: list[str] | None = None,
|
|
187
189
|
**kwargs: Any,
|
|
188
190
|
) -> ViTSTR:
|
|
189
191
|
# Patch the config
|
|
@@ -228,12 +230,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
|
|
|
228
230
|
>>> out = model(input_tensor)
|
|
229
231
|
|
|
230
232
|
Args:
|
|
231
|
-
----
|
|
232
233
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
233
234
|
kwargs: keyword arguments of the ViTSTR architecture
|
|
234
235
|
|
|
235
236
|
Returns:
|
|
236
|
-
-------
|
|
237
237
|
text recognition architecture
|
|
238
238
|
"""
|
|
239
239
|
return _vitstr(
|
|
@@ -259,12 +259,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
|
|
|
259
259
|
>>> out = model(input_tensor)
|
|
260
260
|
|
|
261
261
|
Args:
|
|
262
|
-
----
|
|
263
262
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
264
263
|
kwargs: keyword arguments of the ViTSTR architecture
|
|
265
264
|
|
|
266
265
|
Returns:
|
|
267
|
-
-------
|
|
268
266
|
text recognition architecture
|
|
269
267
|
"""
|
|
270
268
|
return _vitstr(
|