python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
|
@@ -1,11 +1,12 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from copy import deepcopy
|
|
7
8
|
from itertools import groupby
|
|
8
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
9
10
|
|
|
10
11
|
import torch
|
|
11
12
|
from torch import nn
|
|
@@ -19,7 +20,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
|
|
|
19
20
|
|
|
20
21
|
__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
|
|
21
22
|
|
|
22
|
-
default_cfgs:
|
|
23
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
23
24
|
"crnn_vgg16_bn": {
|
|
24
25
|
"mean": (0.694, 0.695, 0.693),
|
|
25
26
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -48,7 +49,6 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
48
49
|
"""Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
49
50
|
|
|
50
51
|
Args:
|
|
51
|
-
----
|
|
52
52
|
vocab: string containing the ordered sequence of supported characters
|
|
53
53
|
"""
|
|
54
54
|
|
|
@@ -57,18 +57,16 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
57
57
|
logits: torch.Tensor,
|
|
58
58
|
vocab: str = VOCABS["french"],
|
|
59
59
|
blank: int = 0,
|
|
60
|
-
) ->
|
|
60
|
+
) -> list[tuple[str, float]]:
|
|
61
61
|
"""Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
|
|
62
62
|
<https://github.com/githubharald/CTCDecoder>`_.
|
|
63
63
|
|
|
64
64
|
Args:
|
|
65
|
-
----
|
|
66
65
|
logits: model output, shape: N x T x C
|
|
67
66
|
vocab: vocabulary to use
|
|
68
67
|
blank: index of blank label
|
|
69
68
|
|
|
70
69
|
Returns:
|
|
71
|
-
-------
|
|
72
70
|
A list of tuples: (word, confidence)
|
|
73
71
|
"""
|
|
74
72
|
# Gather the most confident characters, and assign the smallest conf among those to the sequence prob
|
|
@@ -82,16 +80,14 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
82
80
|
|
|
83
81
|
return list(zip(words, probs.tolist()))
|
|
84
82
|
|
|
85
|
-
def __call__(self, logits: torch.Tensor) ->
|
|
83
|
+
def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
|
|
86
84
|
"""Performs decoding of raw output with CTC and decoding of CTC predictions
|
|
87
85
|
with label_to_idx mapping dictionnary
|
|
88
86
|
|
|
89
87
|
Args:
|
|
90
|
-
----
|
|
91
88
|
logits: raw output of the model, shape (N, C + 1, seq_len)
|
|
92
89
|
|
|
93
90
|
Returns:
|
|
94
|
-
-------
|
|
95
91
|
A tuple of 2 lists: a list of str (words) and a list of float (probs)
|
|
96
92
|
|
|
97
93
|
"""
|
|
@@ -104,7 +100,6 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
104
100
|
Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
|
|
105
101
|
|
|
106
102
|
Args:
|
|
107
|
-
----
|
|
108
103
|
feature_extractor: the backbone serving as feature extractor
|
|
109
104
|
vocab: vocabulary used for encoding
|
|
110
105
|
rnn_units: number of units in the LSTM layers
|
|
@@ -112,16 +107,16 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
112
107
|
cfg: configuration dictionary
|
|
113
108
|
"""
|
|
114
109
|
|
|
115
|
-
_children_names:
|
|
110
|
+
_children_names: list[str] = ["feat_extractor", "decoder", "linear", "postprocessor"]
|
|
116
111
|
|
|
117
112
|
def __init__(
|
|
118
113
|
self,
|
|
119
114
|
feature_extractor: nn.Module,
|
|
120
115
|
vocab: str,
|
|
121
116
|
rnn_units: int = 128,
|
|
122
|
-
input_shape:
|
|
117
|
+
input_shape: tuple[int, int, int] = (3, 32, 128),
|
|
123
118
|
exportable: bool = False,
|
|
124
|
-
cfg:
|
|
119
|
+
cfg: dict[str, Any] | None = None,
|
|
125
120
|
) -> None:
|
|
126
121
|
super().__init__()
|
|
127
122
|
self.vocab = vocab
|
|
@@ -163,17 +158,15 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
163
158
|
def compute_loss(
|
|
164
159
|
self,
|
|
165
160
|
model_output: torch.Tensor,
|
|
166
|
-
target:
|
|
161
|
+
target: list[str],
|
|
167
162
|
) -> torch.Tensor:
|
|
168
163
|
"""Compute CTC loss for the model.
|
|
169
164
|
|
|
170
165
|
Args:
|
|
171
|
-
----
|
|
172
166
|
model_output: predicted logits of the model
|
|
173
167
|
target: list of target strings
|
|
174
168
|
|
|
175
169
|
Returns:
|
|
176
|
-
-------
|
|
177
170
|
The loss of the model on the batch
|
|
178
171
|
"""
|
|
179
172
|
gt, seq_len = self.build_target(target)
|
|
@@ -196,10 +189,10 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
196
189
|
def forward(
|
|
197
190
|
self,
|
|
198
191
|
x: torch.Tensor,
|
|
199
|
-
target:
|
|
192
|
+
target: list[str] | None = None,
|
|
200
193
|
return_model_output: bool = False,
|
|
201
194
|
return_preds: bool = False,
|
|
202
|
-
) ->
|
|
195
|
+
) -> dict[str, Any]:
|
|
203
196
|
if self.training and target is None:
|
|
204
197
|
raise ValueError("Need to provide labels during training")
|
|
205
198
|
|
|
@@ -211,7 +204,7 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
211
204
|
logits, _ = self.decoder(features_seq)
|
|
212
205
|
logits = self.linear(logits)
|
|
213
206
|
|
|
214
|
-
out:
|
|
207
|
+
out: dict[str, Any] = {}
|
|
215
208
|
if self.exportable:
|
|
216
209
|
out["logits"] = logits
|
|
217
210
|
return out
|
|
@@ -220,8 +213,13 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
220
213
|
out["out_map"] = logits
|
|
221
214
|
|
|
222
215
|
if target is None or return_preds:
|
|
216
|
+
# Disable for torch.compile compatibility
|
|
217
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
218
|
+
def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
|
|
219
|
+
return self.postprocessor(logits)
|
|
220
|
+
|
|
223
221
|
# Post-process boxes
|
|
224
|
-
out["preds"] =
|
|
222
|
+
out["preds"] = _postprocess(logits)
|
|
225
223
|
|
|
226
224
|
if target is not None:
|
|
227
225
|
out["loss"] = self.compute_loss(logits, target)
|
|
@@ -234,7 +232,7 @@ def _crnn(
|
|
|
234
232
|
pretrained: bool,
|
|
235
233
|
backbone_fn: Callable[[Any], nn.Module],
|
|
236
234
|
pretrained_backbone: bool = True,
|
|
237
|
-
ignore_keys:
|
|
235
|
+
ignore_keys: list[str] | None = None,
|
|
238
236
|
**kwargs: Any,
|
|
239
237
|
) -> CRNN:
|
|
240
238
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -272,12 +270,10 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
272
270
|
>>> out = model(input_tensor)
|
|
273
271
|
|
|
274
272
|
Args:
|
|
275
|
-
----
|
|
276
273
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
277
274
|
**kwargs: keyword arguments of the CRNN architecture
|
|
278
275
|
|
|
279
276
|
Returns:
|
|
280
|
-
-------
|
|
281
277
|
text recognition architecture
|
|
282
278
|
"""
|
|
283
279
|
return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, ignore_keys=["linear.weight", "linear.bias"], **kwargs)
|
|
@@ -294,12 +290,10 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
294
290
|
>>> out = model(input_tensor)
|
|
295
291
|
|
|
296
292
|
Args:
|
|
297
|
-
----
|
|
298
293
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
299
294
|
**kwargs: keyword arguments of the CRNN architecture
|
|
300
295
|
|
|
301
296
|
Returns:
|
|
302
|
-
-------
|
|
303
297
|
text recognition architecture
|
|
304
298
|
"""
|
|
305
299
|
return _crnn(
|
|
@@ -322,12 +316,10 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
322
316
|
>>> out = model(input_tensor)
|
|
323
317
|
|
|
324
318
|
Args:
|
|
325
|
-
----
|
|
326
319
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
327
320
|
**kwargs: keyword arguments of the CRNN architecture
|
|
328
321
|
|
|
329
322
|
Returns:
|
|
330
|
-
-------
|
|
331
323
|
text recognition architecture
|
|
332
324
|
"""
|
|
333
325
|
return _crnn(
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import tensorflow as tf
|
|
10
10
|
from tensorflow.keras import layers
|
|
@@ -13,32 +13,32 @@ from tensorflow.keras.models import Model, Sequential
|
|
|
13
13
|
from doctr.datasets import VOCABS
|
|
14
14
|
|
|
15
15
|
from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
|
|
16
|
-
from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
|
|
16
|
+
from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
|
|
17
17
|
from ..core import RecognitionModel, RecognitionPostProcessor
|
|
18
18
|
|
|
19
19
|
__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
|
|
20
20
|
|
|
21
|
-
default_cfgs:
|
|
21
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
22
22
|
"crnn_vgg16_bn": {
|
|
23
23
|
"mean": (0.694, 0.695, 0.693),
|
|
24
24
|
"std": (0.299, 0.296, 0.301),
|
|
25
25
|
"input_shape": (32, 128, 3),
|
|
26
26
|
"vocab": VOCABS["legacy_french"],
|
|
27
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
27
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0",
|
|
28
28
|
},
|
|
29
29
|
"crnn_mobilenet_v3_small": {
|
|
30
30
|
"mean": (0.694, 0.695, 0.693),
|
|
31
31
|
"std": (0.299, 0.296, 0.301),
|
|
32
32
|
"input_shape": (32, 128, 3),
|
|
33
33
|
"vocab": VOCABS["french"],
|
|
34
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
34
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0",
|
|
35
35
|
},
|
|
36
36
|
"crnn_mobilenet_v3_large": {
|
|
37
37
|
"mean": (0.694, 0.695, 0.693),
|
|
38
38
|
"std": (0.299, 0.296, 0.301),
|
|
39
39
|
"input_shape": (32, 128, 3),
|
|
40
40
|
"vocab": VOCABS["french"],
|
|
41
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
41
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0",
|
|
42
42
|
},
|
|
43
43
|
}
|
|
44
44
|
|
|
@@ -47,7 +47,6 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
47
47
|
"""Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
48
48
|
|
|
49
49
|
Args:
|
|
50
|
-
----
|
|
51
50
|
vocab: string containing the ordered sequence of supported characters
|
|
52
51
|
ignore_case: if True, ignore case of letters
|
|
53
52
|
ignore_accents: if True, ignore accents of letters
|
|
@@ -58,18 +57,16 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
58
57
|
logits: tf.Tensor,
|
|
59
58
|
beam_width: int = 1,
|
|
60
59
|
top_paths: int = 1,
|
|
61
|
-
) ->
|
|
60
|
+
) -> list[tuple[str, float]] | list[tuple[list[str] | list[float]]]:
|
|
62
61
|
"""Performs decoding of raw output with CTC and decoding of CTC predictions
|
|
63
62
|
with label_to_idx mapping dictionnary
|
|
64
63
|
|
|
65
64
|
Args:
|
|
66
|
-
----
|
|
67
65
|
logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1
|
|
68
66
|
beam_width: An int scalar >= 0 (beam search beam width).
|
|
69
67
|
top_paths: An int scalar >= 0, <= beam_width (controls output size).
|
|
70
68
|
|
|
71
69
|
Returns:
|
|
72
|
-
-------
|
|
73
70
|
A list of decoded words of length BATCH_SIZE
|
|
74
71
|
|
|
75
72
|
|
|
@@ -114,7 +111,6 @@ class CRNN(RecognitionModel, Model):
|
|
|
114
111
|
Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
|
|
115
112
|
|
|
116
113
|
Args:
|
|
117
|
-
----
|
|
118
114
|
feature_extractor: the backbone serving as feature extractor
|
|
119
115
|
vocab: vocabulary used for encoding
|
|
120
116
|
rnn_units: number of units in the LSTM layers
|
|
@@ -124,17 +120,17 @@ class CRNN(RecognitionModel, Model):
|
|
|
124
120
|
cfg: configuration dictionary
|
|
125
121
|
"""
|
|
126
122
|
|
|
127
|
-
_children_names:
|
|
123
|
+
_children_names: list[str] = ["feat_extractor", "decoder", "postprocessor"]
|
|
128
124
|
|
|
129
125
|
def __init__(
|
|
130
126
|
self,
|
|
131
|
-
feature_extractor:
|
|
127
|
+
feature_extractor: Model,
|
|
132
128
|
vocab: str,
|
|
133
129
|
rnn_units: int = 128,
|
|
134
130
|
exportable: bool = False,
|
|
135
131
|
beam_width: int = 1,
|
|
136
132
|
top_paths: int = 1,
|
|
137
|
-
cfg:
|
|
133
|
+
cfg: dict[str, Any] | None = None,
|
|
138
134
|
) -> None:
|
|
139
135
|
# Initialize kernels
|
|
140
136
|
h, w, c = feature_extractor.output_shape[1:]
|
|
@@ -161,17 +157,15 @@ class CRNN(RecognitionModel, Model):
|
|
|
161
157
|
def compute_loss(
|
|
162
158
|
self,
|
|
163
159
|
model_output: tf.Tensor,
|
|
164
|
-
target:
|
|
160
|
+
target: list[str],
|
|
165
161
|
) -> tf.Tensor:
|
|
166
162
|
"""Compute CTC loss for the model.
|
|
167
163
|
|
|
168
164
|
Args:
|
|
169
|
-
----
|
|
170
165
|
model_output: predicted logits of the model
|
|
171
166
|
target: lengths of each gt word inside the batch
|
|
172
167
|
|
|
173
168
|
Returns:
|
|
174
|
-
-------
|
|
175
169
|
The loss of the model on the batch
|
|
176
170
|
"""
|
|
177
171
|
gt, seq_len = self.build_target(target)
|
|
@@ -185,13 +179,13 @@ class CRNN(RecognitionModel, Model):
|
|
|
185
179
|
def call(
|
|
186
180
|
self,
|
|
187
181
|
x: tf.Tensor,
|
|
188
|
-
target:
|
|
182
|
+
target: list[str] | None = None,
|
|
189
183
|
return_model_output: bool = False,
|
|
190
184
|
return_preds: bool = False,
|
|
191
185
|
beam_width: int = 1,
|
|
192
186
|
top_paths: int = 1,
|
|
193
187
|
**kwargs: Any,
|
|
194
|
-
) ->
|
|
188
|
+
) -> dict[str, Any]:
|
|
195
189
|
if kwargs.get("training", False) and target is None:
|
|
196
190
|
raise ValueError("Need to provide labels during training")
|
|
197
191
|
|
|
@@ -203,7 +197,7 @@ class CRNN(RecognitionModel, Model):
|
|
|
203
197
|
features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
|
|
204
198
|
logits = _bf16_to_float32(self.decoder(features_seq, **kwargs))
|
|
205
199
|
|
|
206
|
-
out:
|
|
200
|
+
out: dict[str, tf.Tensor] = {}
|
|
207
201
|
if self.exportable:
|
|
208
202
|
out["logits"] = logits
|
|
209
203
|
return out
|
|
@@ -226,7 +220,7 @@ def _crnn(
|
|
|
226
220
|
pretrained: bool,
|
|
227
221
|
backbone_fn,
|
|
228
222
|
pretrained_backbone: bool = True,
|
|
229
|
-
input_shape:
|
|
223
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
230
224
|
**kwargs: Any,
|
|
231
225
|
) -> CRNN:
|
|
232
226
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -245,9 +239,11 @@ def _crnn(
|
|
|
245
239
|
|
|
246
240
|
# Build the model
|
|
247
241
|
model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
|
|
242
|
+
_build_model(model)
|
|
248
243
|
# Load pretrained parameters
|
|
249
244
|
if pretrained:
|
|
250
|
-
|
|
245
|
+
# The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
|
|
246
|
+
load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
|
|
251
247
|
|
|
252
248
|
return model
|
|
253
249
|
|
|
@@ -263,12 +259,10 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
263
259
|
>>> out = model(input_tensor)
|
|
264
260
|
|
|
265
261
|
Args:
|
|
266
|
-
----
|
|
267
262
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
268
263
|
**kwargs: keyword arguments of the CRNN architecture
|
|
269
264
|
|
|
270
265
|
Returns:
|
|
271
|
-
-------
|
|
272
266
|
text recognition architecture
|
|
273
267
|
"""
|
|
274
268
|
return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs)
|
|
@@ -285,12 +279,10 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
285
279
|
>>> out = model(input_tensor)
|
|
286
280
|
|
|
287
281
|
Args:
|
|
288
|
-
----
|
|
289
282
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
290
283
|
**kwargs: keyword arguments of the CRNN architecture
|
|
291
284
|
|
|
292
285
|
Returns:
|
|
293
|
-
-------
|
|
294
286
|
text recognition architecture
|
|
295
287
|
"""
|
|
296
288
|
return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs)
|
|
@@ -307,12 +299,10 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
307
299
|
>>> out = model(input_tensor)
|
|
308
300
|
|
|
309
301
|
Args:
|
|
310
|
-
----
|
|
311
302
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
312
303
|
**kwargs: keyword arguments of the CRNN architecture
|
|
313
304
|
|
|
314
305
|
Returns:
|
|
315
|
-
-------
|
|
316
306
|
text recognition architecture
|
|
317
307
|
"""
|
|
318
308
|
return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List, Tuple
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
|
|
@@ -17,17 +16,15 @@ class _MASTER:
|
|
|
17
16
|
|
|
18
17
|
def build_target(
|
|
19
18
|
self,
|
|
20
|
-
gts:
|
|
21
|
-
) ->
|
|
19
|
+
gts: list[str],
|
|
20
|
+
) -> tuple[np.ndarray, list[int]]:
|
|
22
21
|
"""Encode a list of gts sequences into a np array and gives the corresponding*
|
|
23
22
|
sequence lengths.
|
|
24
23
|
|
|
25
24
|
Args:
|
|
26
|
-
----
|
|
27
25
|
gts: list of ground-truth labels
|
|
28
26
|
|
|
29
27
|
Returns:
|
|
30
|
-
-------
|
|
31
28
|
A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
|
|
32
29
|
"""
|
|
33
30
|
encoded = encode_sequences(
|
|
@@ -46,7 +43,6 @@ class _MASTERPostProcessor(RecognitionPostProcessor):
|
|
|
46
43
|
"""Abstract class to postprocess the raw output of the model
|
|
47
44
|
|
|
48
45
|
Args:
|
|
49
|
-
----
|
|
50
46
|
vocab: string containing the ordered sequence of supported characters
|
|
51
47
|
"""
|
|
52
48
|
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
8
9
|
|
|
9
10
|
import torch
|
|
10
11
|
from torch import nn
|
|
@@ -21,7 +22,7 @@ from .base import _MASTER, _MASTERPostProcessor
|
|
|
21
22
|
__all__ = ["MASTER", "master"]
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
default_cfgs:
|
|
25
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
25
26
|
"master": {
|
|
26
27
|
"mean": (0.694, 0.695, 0.693),
|
|
27
28
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -37,7 +38,6 @@ class MASTER(_MASTER, nn.Module):
|
|
|
37
38
|
Implementation based on the official Pytorch implementation: <https://github.com/wenwenyu/MASTER-pytorch>`_.
|
|
38
39
|
|
|
39
40
|
Args:
|
|
40
|
-
----
|
|
41
41
|
feature_extractor: the backbone serving as feature extractor
|
|
42
42
|
vocab: vocabulary, (without EOS, SOS, PAD)
|
|
43
43
|
d_model: d parameter for the transformer decoder
|
|
@@ -61,9 +61,9 @@ class MASTER(_MASTER, nn.Module):
|
|
|
61
61
|
num_layers: int = 3,
|
|
62
62
|
max_length: int = 50,
|
|
63
63
|
dropout: float = 0.2,
|
|
64
|
-
input_shape:
|
|
64
|
+
input_shape: tuple[int, int, int] = (3, 32, 128), # different from the paper
|
|
65
65
|
exportable: bool = False,
|
|
66
|
-
cfg:
|
|
66
|
+
cfg: dict[str, Any] | None = None,
|
|
67
67
|
) -> None:
|
|
68
68
|
super().__init__()
|
|
69
69
|
|
|
@@ -102,12 +102,12 @@ class MASTER(_MASTER, nn.Module):
|
|
|
102
102
|
|
|
103
103
|
def make_source_and_target_mask(
|
|
104
104
|
self, source: torch.Tensor, target: torch.Tensor
|
|
105
|
-
) ->
|
|
105
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
106
106
|
# borrowed and slightly modified from https://github.com/wenwenyu/MASTER-pytorch
|
|
107
107
|
# NOTE: nn.TransformerDecoder takes the inverse from this implementation
|
|
108
108
|
# [True, True, True, ..., False, False, False] -> False is masked
|
|
109
109
|
# (N, 1, 1, max_length)
|
|
110
|
-
target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
|
|
110
|
+
target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
|
|
111
111
|
target_length = target.size(1)
|
|
112
112
|
# sub mask filled diagonal with True = see and False = masked (max_length, max_length)
|
|
113
113
|
# NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
|
|
@@ -130,19 +130,17 @@ class MASTER(_MASTER, nn.Module):
|
|
|
130
130
|
Sequences are masked after the EOS character.
|
|
131
131
|
|
|
132
132
|
Args:
|
|
133
|
-
----
|
|
134
133
|
gt: the encoded tensor with gt labels
|
|
135
134
|
model_output: predicted logits of the model
|
|
136
135
|
seq_len: lengths of each gt word inside the batch
|
|
137
136
|
|
|
138
137
|
Returns:
|
|
139
|
-
-------
|
|
140
138
|
The loss of the model on the batch
|
|
141
139
|
"""
|
|
142
140
|
# Input length : number of timesteps
|
|
143
141
|
input_len = model_output.shape[1]
|
|
144
142
|
# Add one for additional <eos> token (sos disappear in shift!)
|
|
145
|
-
seq_len = seq_len + 1
|
|
143
|
+
seq_len = seq_len + 1 # type: ignore[assignment]
|
|
146
144
|
# Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
|
|
147
145
|
# The "masked" first gt char is <sos>. Delete last logit of the model output.
|
|
148
146
|
cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
|
|
@@ -156,21 +154,19 @@ class MASTER(_MASTER, nn.Module):
|
|
|
156
154
|
def forward(
|
|
157
155
|
self,
|
|
158
156
|
x: torch.Tensor,
|
|
159
|
-
target:
|
|
157
|
+
target: list[str] | None = None,
|
|
160
158
|
return_model_output: bool = False,
|
|
161
159
|
return_preds: bool = False,
|
|
162
|
-
) ->
|
|
160
|
+
) -> dict[str, Any]:
|
|
163
161
|
"""Call function for training
|
|
164
162
|
|
|
165
163
|
Args:
|
|
166
|
-
----
|
|
167
164
|
x: images
|
|
168
165
|
target: list of str labels
|
|
169
166
|
return_model_output: if True, return logits
|
|
170
167
|
return_preds: if True, decode logits
|
|
171
168
|
|
|
172
169
|
Returns:
|
|
173
|
-
-------
|
|
174
170
|
A dictionnary containing eventually loss, logits and predictions.
|
|
175
171
|
"""
|
|
176
172
|
# Encode
|
|
@@ -181,7 +177,7 @@ class MASTER(_MASTER, nn.Module):
|
|
|
181
177
|
# add positional encoding to features
|
|
182
178
|
encoded = self.positional_encoding(features)
|
|
183
179
|
|
|
184
|
-
out:
|
|
180
|
+
out: dict[str, Any] = {}
|
|
185
181
|
|
|
186
182
|
if self.training and target is None:
|
|
187
183
|
raise ValueError("Need to provide labels during training")
|
|
@@ -213,7 +209,13 @@ class MASTER(_MASTER, nn.Module):
|
|
|
213
209
|
out["out_map"] = logits
|
|
214
210
|
|
|
215
211
|
if return_preds:
|
|
216
|
-
|
|
212
|
+
# Disable for torch.compile compatibility
|
|
213
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
214
|
+
def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
|
|
215
|
+
return self.postprocessor(logits)
|
|
216
|
+
|
|
217
|
+
# Post-process boxes
|
|
218
|
+
out["preds"] = _postprocess(logits)
|
|
217
219
|
|
|
218
220
|
return out
|
|
219
221
|
|
|
@@ -221,12 +223,10 @@ class MASTER(_MASTER, nn.Module):
|
|
|
221
223
|
"""Decode function for prediction
|
|
222
224
|
|
|
223
225
|
Args:
|
|
224
|
-
----
|
|
225
226
|
encoded: input tensor
|
|
226
227
|
|
|
227
228
|
Returns:
|
|
228
|
-
|
|
229
|
-
A Tuple of torch.Tensor: predictions, logits
|
|
229
|
+
A tuple of torch.Tensor: predictions, logits
|
|
230
230
|
"""
|
|
231
231
|
b = encoded.size(0)
|
|
232
232
|
|
|
@@ -254,7 +254,7 @@ class MASTERPostProcessor(_MASTERPostProcessor):
|
|
|
254
254
|
def __call__(
|
|
255
255
|
self,
|
|
256
256
|
logits: torch.Tensor,
|
|
257
|
-
) ->
|
|
257
|
+
) -> list[tuple[str, float]]:
|
|
258
258
|
# compute pred with argmax for attention models
|
|
259
259
|
out_idxs = logits.argmax(-1)
|
|
260
260
|
# N x L
|
|
@@ -277,7 +277,7 @@ def _master(
|
|
|
277
277
|
backbone_fn: Callable[[bool], nn.Module],
|
|
278
278
|
layer: str,
|
|
279
279
|
pretrained_backbone: bool = True,
|
|
280
|
-
ignore_keys:
|
|
280
|
+
ignore_keys: list[str] | None = None,
|
|
281
281
|
**kwargs: Any,
|
|
282
282
|
) -> MASTER:
|
|
283
283
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -316,12 +316,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
|
|
|
316
316
|
>>> out = model(input_tensor)
|
|
317
317
|
|
|
318
318
|
Args:
|
|
319
|
-
----
|
|
320
319
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
321
320
|
**kwargs: keywoard arguments passed to the MASTER architecture
|
|
322
321
|
|
|
323
322
|
Returns:
|
|
324
|
-
-------
|
|
325
323
|
text recognition architecture
|
|
326
324
|
"""
|
|
327
325
|
return _master(
|