python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
|
@@ -1,11 +1,12 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from copy import deepcopy
|
|
7
8
|
from itertools import groupby
|
|
8
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
9
10
|
|
|
10
11
|
import torch
|
|
11
12
|
from torch import nn
|
|
@@ -19,7 +20,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
|
|
|
19
20
|
|
|
20
21
|
__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
|
|
21
22
|
|
|
22
|
-
default_cfgs:
|
|
23
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
23
24
|
"crnn_vgg16_bn": {
|
|
24
25
|
"mean": (0.694, 0.695, 0.693),
|
|
25
26
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -48,7 +49,6 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
48
49
|
"""Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
49
50
|
|
|
50
51
|
Args:
|
|
51
|
-
----
|
|
52
52
|
vocab: string containing the ordered sequence of supported characters
|
|
53
53
|
"""
|
|
54
54
|
|
|
@@ -57,18 +57,16 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
57
57
|
logits: torch.Tensor,
|
|
58
58
|
vocab: str = VOCABS["french"],
|
|
59
59
|
blank: int = 0,
|
|
60
|
-
) ->
|
|
60
|
+
) -> list[tuple[str, float]]:
|
|
61
61
|
"""Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
|
|
62
62
|
<https://github.com/githubharald/CTCDecoder>`_.
|
|
63
63
|
|
|
64
64
|
Args:
|
|
65
|
-
----
|
|
66
65
|
logits: model output, shape: N x T x C
|
|
67
66
|
vocab: vocabulary to use
|
|
68
67
|
blank: index of blank label
|
|
69
68
|
|
|
70
69
|
Returns:
|
|
71
|
-
-------
|
|
72
70
|
A list of tuples: (word, confidence)
|
|
73
71
|
"""
|
|
74
72
|
# Gather the most confident characters, and assign the smallest conf among those to the sequence prob
|
|
@@ -82,16 +80,14 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
82
80
|
|
|
83
81
|
return list(zip(words, probs.tolist()))
|
|
84
82
|
|
|
85
|
-
def __call__(self, logits: torch.Tensor) ->
|
|
83
|
+
def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
|
|
86
84
|
"""Performs decoding of raw output with CTC and decoding of CTC predictions
|
|
87
85
|
with label_to_idx mapping dictionnary
|
|
88
86
|
|
|
89
87
|
Args:
|
|
90
|
-
----
|
|
91
88
|
logits: raw output of the model, shape (N, C + 1, seq_len)
|
|
92
89
|
|
|
93
90
|
Returns:
|
|
94
|
-
-------
|
|
95
91
|
A tuple of 2 lists: a list of str (words) and a list of float (probs)
|
|
96
92
|
|
|
97
93
|
"""
|
|
@@ -104,7 +100,6 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
104
100
|
Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
|
|
105
101
|
|
|
106
102
|
Args:
|
|
107
|
-
----
|
|
108
103
|
feature_extractor: the backbone serving as feature extractor
|
|
109
104
|
vocab: vocabulary used for encoding
|
|
110
105
|
rnn_units: number of units in the LSTM layers
|
|
@@ -112,16 +107,16 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
112
107
|
cfg: configuration dictionary
|
|
113
108
|
"""
|
|
114
109
|
|
|
115
|
-
_children_names:
|
|
110
|
+
_children_names: list[str] = ["feat_extractor", "decoder", "linear", "postprocessor"]
|
|
116
111
|
|
|
117
112
|
def __init__(
|
|
118
113
|
self,
|
|
119
114
|
feature_extractor: nn.Module,
|
|
120
115
|
vocab: str,
|
|
121
116
|
rnn_units: int = 128,
|
|
122
|
-
input_shape:
|
|
117
|
+
input_shape: tuple[int, int, int] = (3, 32, 128),
|
|
123
118
|
exportable: bool = False,
|
|
124
|
-
cfg:
|
|
119
|
+
cfg: dict[str, Any] | None = None,
|
|
125
120
|
) -> None:
|
|
126
121
|
super().__init__()
|
|
127
122
|
self.vocab = vocab
|
|
@@ -160,20 +155,27 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
160
155
|
m.weight.data.fill_(1.0)
|
|
161
156
|
m.bias.data.zero_()
|
|
162
157
|
|
|
158
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
159
|
+
"""Load pretrained parameters onto the model
|
|
160
|
+
|
|
161
|
+
Args:
|
|
162
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
163
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
164
|
+
"""
|
|
165
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
166
|
+
|
|
163
167
|
def compute_loss(
|
|
164
168
|
self,
|
|
165
169
|
model_output: torch.Tensor,
|
|
166
|
-
target:
|
|
170
|
+
target: list[str],
|
|
167
171
|
) -> torch.Tensor:
|
|
168
172
|
"""Compute CTC loss for the model.
|
|
169
173
|
|
|
170
174
|
Args:
|
|
171
|
-
----
|
|
172
175
|
model_output: predicted logits of the model
|
|
173
176
|
target: list of target strings
|
|
174
177
|
|
|
175
178
|
Returns:
|
|
176
|
-
-------
|
|
177
179
|
The loss of the model on the batch
|
|
178
180
|
"""
|
|
179
181
|
gt, seq_len = self.build_target(target)
|
|
@@ -196,10 +198,10 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
196
198
|
def forward(
|
|
197
199
|
self,
|
|
198
200
|
x: torch.Tensor,
|
|
199
|
-
target:
|
|
201
|
+
target: list[str] | None = None,
|
|
200
202
|
return_model_output: bool = False,
|
|
201
203
|
return_preds: bool = False,
|
|
202
|
-
) ->
|
|
204
|
+
) -> dict[str, Any]:
|
|
203
205
|
if self.training and target is None:
|
|
204
206
|
raise ValueError("Need to provide labels during training")
|
|
205
207
|
|
|
@@ -211,7 +213,7 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
211
213
|
logits, _ = self.decoder(features_seq)
|
|
212
214
|
logits = self.linear(logits)
|
|
213
215
|
|
|
214
|
-
out:
|
|
216
|
+
out: dict[str, Any] = {}
|
|
215
217
|
if self.exportable:
|
|
216
218
|
out["logits"] = logits
|
|
217
219
|
return out
|
|
@@ -220,8 +222,13 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
220
222
|
out["out_map"] = logits
|
|
221
223
|
|
|
222
224
|
if target is None or return_preds:
|
|
225
|
+
# Disable for torch.compile compatibility
|
|
226
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
227
|
+
def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
|
|
228
|
+
return self.postprocessor(logits)
|
|
229
|
+
|
|
223
230
|
# Post-process boxes
|
|
224
|
-
out["preds"] =
|
|
231
|
+
out["preds"] = _postprocess(logits)
|
|
225
232
|
|
|
226
233
|
if target is not None:
|
|
227
234
|
out["loss"] = self.compute_loss(logits, target)
|
|
@@ -234,7 +241,7 @@ def _crnn(
|
|
|
234
241
|
pretrained: bool,
|
|
235
242
|
backbone_fn: Callable[[Any], nn.Module],
|
|
236
243
|
pretrained_backbone: bool = True,
|
|
237
|
-
ignore_keys:
|
|
244
|
+
ignore_keys: list[str] | None = None,
|
|
238
245
|
**kwargs: Any,
|
|
239
246
|
) -> CRNN:
|
|
240
247
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -256,7 +263,7 @@ def _crnn(
|
|
|
256
263
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
257
264
|
# remove the last layer weights
|
|
258
265
|
_ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
|
|
259
|
-
|
|
266
|
+
model.from_pretrained(_cfg["url"], ignore_keys=_ignore_keys)
|
|
260
267
|
|
|
261
268
|
return model
|
|
262
269
|
|
|
@@ -272,12 +279,10 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
272
279
|
>>> out = model(input_tensor)
|
|
273
280
|
|
|
274
281
|
Args:
|
|
275
|
-
----
|
|
276
282
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
277
283
|
**kwargs: keyword arguments of the CRNN architecture
|
|
278
284
|
|
|
279
285
|
Returns:
|
|
280
|
-
-------
|
|
281
286
|
text recognition architecture
|
|
282
287
|
"""
|
|
283
288
|
return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, ignore_keys=["linear.weight", "linear.bias"], **kwargs)
|
|
@@ -294,12 +299,10 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
294
299
|
>>> out = model(input_tensor)
|
|
295
300
|
|
|
296
301
|
Args:
|
|
297
|
-
----
|
|
298
302
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
299
303
|
**kwargs: keyword arguments of the CRNN architecture
|
|
300
304
|
|
|
301
305
|
Returns:
|
|
302
|
-
-------
|
|
303
306
|
text recognition architecture
|
|
304
307
|
"""
|
|
305
308
|
return _crnn(
|
|
@@ -322,12 +325,10 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
322
325
|
>>> out = model(input_tensor)
|
|
323
326
|
|
|
324
327
|
Args:
|
|
325
|
-
----
|
|
326
328
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
327
329
|
**kwargs: keyword arguments of the CRNN architecture
|
|
328
330
|
|
|
329
331
|
Returns:
|
|
330
|
-
-------
|
|
331
332
|
text recognition architecture
|
|
332
333
|
"""
|
|
333
334
|
return _crnn(
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import tensorflow as tf
|
|
10
10
|
from tensorflow.keras import layers
|
|
@@ -18,7 +18,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
|
|
|
18
18
|
|
|
19
19
|
__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
|
|
20
20
|
|
|
21
|
-
default_cfgs:
|
|
21
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
22
22
|
"crnn_vgg16_bn": {
|
|
23
23
|
"mean": (0.694, 0.695, 0.693),
|
|
24
24
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -47,7 +47,6 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
47
47
|
"""Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
48
48
|
|
|
49
49
|
Args:
|
|
50
|
-
----
|
|
51
50
|
vocab: string containing the ordered sequence of supported characters
|
|
52
51
|
ignore_case: if True, ignore case of letters
|
|
53
52
|
ignore_accents: if True, ignore accents of letters
|
|
@@ -58,18 +57,16 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
58
57
|
logits: tf.Tensor,
|
|
59
58
|
beam_width: int = 1,
|
|
60
59
|
top_paths: int = 1,
|
|
61
|
-
) ->
|
|
60
|
+
) -> list[tuple[str, float]] | list[tuple[list[str] | list[float]]]:
|
|
62
61
|
"""Performs decoding of raw output with CTC and decoding of CTC predictions
|
|
63
62
|
with label_to_idx mapping dictionnary
|
|
64
63
|
|
|
65
64
|
Args:
|
|
66
|
-
----
|
|
67
65
|
logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1
|
|
68
66
|
beam_width: An int scalar >= 0 (beam search beam width).
|
|
69
67
|
top_paths: An int scalar >= 0, <= beam_width (controls output size).
|
|
70
68
|
|
|
71
69
|
Returns:
|
|
72
|
-
-------
|
|
73
70
|
A list of decoded words of length BATCH_SIZE
|
|
74
71
|
|
|
75
72
|
|
|
@@ -114,7 +111,6 @@ class CRNN(RecognitionModel, Model):
|
|
|
114
111
|
Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
|
|
115
112
|
|
|
116
113
|
Args:
|
|
117
|
-
----
|
|
118
114
|
feature_extractor: the backbone serving as feature extractor
|
|
119
115
|
vocab: vocabulary used for encoding
|
|
120
116
|
rnn_units: number of units in the LSTM layers
|
|
@@ -124,7 +120,7 @@ class CRNN(RecognitionModel, Model):
|
|
|
124
120
|
cfg: configuration dictionary
|
|
125
121
|
"""
|
|
126
122
|
|
|
127
|
-
_children_names:
|
|
123
|
+
_children_names: list[str] = ["feat_extractor", "decoder", "postprocessor"]
|
|
128
124
|
|
|
129
125
|
def __init__(
|
|
130
126
|
self,
|
|
@@ -134,7 +130,7 @@ class CRNN(RecognitionModel, Model):
|
|
|
134
130
|
exportable: bool = False,
|
|
135
131
|
beam_width: int = 1,
|
|
136
132
|
top_paths: int = 1,
|
|
137
|
-
cfg:
|
|
133
|
+
cfg: dict[str, Any] | None = None,
|
|
138
134
|
) -> None:
|
|
139
135
|
# Initialize kernels
|
|
140
136
|
h, w, c = feature_extractor.output_shape[1:]
|
|
@@ -158,20 +154,27 @@ class CRNN(RecognitionModel, Model):
|
|
|
158
154
|
self.beam_width = beam_width
|
|
159
155
|
self.top_paths = top_paths
|
|
160
156
|
|
|
157
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
158
|
+
"""Load pretrained parameters onto the model
|
|
159
|
+
|
|
160
|
+
Args:
|
|
161
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
162
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
163
|
+
"""
|
|
164
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
165
|
+
|
|
161
166
|
def compute_loss(
|
|
162
167
|
self,
|
|
163
168
|
model_output: tf.Tensor,
|
|
164
|
-
target:
|
|
169
|
+
target: list[str],
|
|
165
170
|
) -> tf.Tensor:
|
|
166
171
|
"""Compute CTC loss for the model.
|
|
167
172
|
|
|
168
173
|
Args:
|
|
169
|
-
----
|
|
170
174
|
model_output: predicted logits of the model
|
|
171
175
|
target: lengths of each gt word inside the batch
|
|
172
176
|
|
|
173
177
|
Returns:
|
|
174
|
-
-------
|
|
175
178
|
The loss of the model on the batch
|
|
176
179
|
"""
|
|
177
180
|
gt, seq_len = self.build_target(target)
|
|
@@ -185,13 +188,13 @@ class CRNN(RecognitionModel, Model):
|
|
|
185
188
|
def call(
|
|
186
189
|
self,
|
|
187
190
|
x: tf.Tensor,
|
|
188
|
-
target:
|
|
191
|
+
target: list[str] | None = None,
|
|
189
192
|
return_model_output: bool = False,
|
|
190
193
|
return_preds: bool = False,
|
|
191
194
|
beam_width: int = 1,
|
|
192
195
|
top_paths: int = 1,
|
|
193
196
|
**kwargs: Any,
|
|
194
|
-
) ->
|
|
197
|
+
) -> dict[str, Any]:
|
|
195
198
|
if kwargs.get("training", False) and target is None:
|
|
196
199
|
raise ValueError("Need to provide labels during training")
|
|
197
200
|
|
|
@@ -203,7 +206,7 @@ class CRNN(RecognitionModel, Model):
|
|
|
203
206
|
features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
|
|
204
207
|
logits = _bf16_to_float32(self.decoder(features_seq, **kwargs))
|
|
205
208
|
|
|
206
|
-
out:
|
|
209
|
+
out: dict[str, tf.Tensor] = {}
|
|
207
210
|
if self.exportable:
|
|
208
211
|
out["logits"] = logits
|
|
209
212
|
return out
|
|
@@ -226,7 +229,7 @@ def _crnn(
|
|
|
226
229
|
pretrained: bool,
|
|
227
230
|
backbone_fn,
|
|
228
231
|
pretrained_backbone: bool = True,
|
|
229
|
-
input_shape:
|
|
232
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
230
233
|
**kwargs: Any,
|
|
231
234
|
) -> CRNN:
|
|
232
235
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -249,7 +252,7 @@ def _crnn(
|
|
|
249
252
|
# Load pretrained parameters
|
|
250
253
|
if pretrained:
|
|
251
254
|
# The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
|
|
252
|
-
|
|
255
|
+
model.from_pretrained(_cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
|
|
253
256
|
|
|
254
257
|
return model
|
|
255
258
|
|
|
@@ -265,12 +268,10 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
265
268
|
>>> out = model(input_tensor)
|
|
266
269
|
|
|
267
270
|
Args:
|
|
268
|
-
----
|
|
269
271
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
270
272
|
**kwargs: keyword arguments of the CRNN architecture
|
|
271
273
|
|
|
272
274
|
Returns:
|
|
273
|
-
-------
|
|
274
275
|
text recognition architecture
|
|
275
276
|
"""
|
|
276
277
|
return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs)
|
|
@@ -287,12 +288,10 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
287
288
|
>>> out = model(input_tensor)
|
|
288
289
|
|
|
289
290
|
Args:
|
|
290
|
-
----
|
|
291
291
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
292
292
|
**kwargs: keyword arguments of the CRNN architecture
|
|
293
293
|
|
|
294
294
|
Returns:
|
|
295
|
-
-------
|
|
296
295
|
text recognition architecture
|
|
297
296
|
"""
|
|
298
297
|
return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs)
|
|
@@ -309,12 +308,10 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
309
308
|
>>> out = model(input_tensor)
|
|
310
309
|
|
|
311
310
|
Args:
|
|
312
|
-
----
|
|
313
311
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
314
312
|
**kwargs: keyword arguments of the CRNN architecture
|
|
315
313
|
|
|
316
314
|
Returns:
|
|
317
|
-
-------
|
|
318
315
|
text recognition architecture
|
|
319
316
|
"""
|
|
320
317
|
return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List, Tuple
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
|
|
@@ -17,17 +16,15 @@ class _MASTER:
|
|
|
17
16
|
|
|
18
17
|
def build_target(
|
|
19
18
|
self,
|
|
20
|
-
gts:
|
|
21
|
-
) ->
|
|
19
|
+
gts: list[str],
|
|
20
|
+
) -> tuple[np.ndarray, list[int]]:
|
|
22
21
|
"""Encode a list of gts sequences into a np array and gives the corresponding*
|
|
23
22
|
sequence lengths.
|
|
24
23
|
|
|
25
24
|
Args:
|
|
26
|
-
----
|
|
27
25
|
gts: list of ground-truth labels
|
|
28
26
|
|
|
29
27
|
Returns:
|
|
30
|
-
-------
|
|
31
28
|
A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
|
|
32
29
|
"""
|
|
33
30
|
encoded = encode_sequences(
|
|
@@ -46,7 +43,6 @@ class _MASTERPostProcessor(RecognitionPostProcessor):
|
|
|
46
43
|
"""Abstract class to postprocess the raw output of the model
|
|
47
44
|
|
|
48
45
|
Args:
|
|
49
|
-
----
|
|
50
46
|
vocab: string containing the ordered sequence of supported characters
|
|
51
47
|
"""
|
|
52
48
|
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
8
9
|
|
|
9
10
|
import torch
|
|
10
11
|
from torch import nn
|
|
@@ -21,7 +22,7 @@ from .base import _MASTER, _MASTERPostProcessor
|
|
|
21
22
|
__all__ = ["MASTER", "master"]
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
default_cfgs:
|
|
25
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
25
26
|
"master": {
|
|
26
27
|
"mean": (0.694, 0.695, 0.693),
|
|
27
28
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -37,7 +38,6 @@ class MASTER(_MASTER, nn.Module):
|
|
|
37
38
|
Implementation based on the official Pytorch implementation: <https://github.com/wenwenyu/MASTER-pytorch>`_.
|
|
38
39
|
|
|
39
40
|
Args:
|
|
40
|
-
----
|
|
41
41
|
feature_extractor: the backbone serving as feature extractor
|
|
42
42
|
vocab: vocabulary, (without EOS, SOS, PAD)
|
|
43
43
|
d_model: d parameter for the transformer decoder
|
|
@@ -61,9 +61,9 @@ class MASTER(_MASTER, nn.Module):
|
|
|
61
61
|
num_layers: int = 3,
|
|
62
62
|
max_length: int = 50,
|
|
63
63
|
dropout: float = 0.2,
|
|
64
|
-
input_shape:
|
|
64
|
+
input_shape: tuple[int, int, int] = (3, 32, 128), # different from the paper
|
|
65
65
|
exportable: bool = False,
|
|
66
|
-
cfg:
|
|
66
|
+
cfg: dict[str, Any] | None = None,
|
|
67
67
|
) -> None:
|
|
68
68
|
super().__init__()
|
|
69
69
|
|
|
@@ -102,12 +102,12 @@ class MASTER(_MASTER, nn.Module):
|
|
|
102
102
|
|
|
103
103
|
def make_source_and_target_mask(
|
|
104
104
|
self, source: torch.Tensor, target: torch.Tensor
|
|
105
|
-
) ->
|
|
105
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
106
106
|
# borrowed and slightly modified from https://github.com/wenwenyu/MASTER-pytorch
|
|
107
107
|
# NOTE: nn.TransformerDecoder takes the inverse from this implementation
|
|
108
108
|
# [True, True, True, ..., False, False, False] -> False is masked
|
|
109
109
|
# (N, 1, 1, max_length)
|
|
110
|
-
target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
|
|
110
|
+
target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
|
|
111
111
|
target_length = target.size(1)
|
|
112
112
|
# sub mask filled diagonal with True = see and False = masked (max_length, max_length)
|
|
113
113
|
# NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
|
|
@@ -130,19 +130,17 @@ class MASTER(_MASTER, nn.Module):
|
|
|
130
130
|
Sequences are masked after the EOS character.
|
|
131
131
|
|
|
132
132
|
Args:
|
|
133
|
-
----
|
|
134
133
|
gt: the encoded tensor with gt labels
|
|
135
134
|
model_output: predicted logits of the model
|
|
136
135
|
seq_len: lengths of each gt word inside the batch
|
|
137
136
|
|
|
138
137
|
Returns:
|
|
139
|
-
-------
|
|
140
138
|
The loss of the model on the batch
|
|
141
139
|
"""
|
|
142
140
|
# Input length : number of timesteps
|
|
143
141
|
input_len = model_output.shape[1]
|
|
144
142
|
# Add one for additional <eos> token (sos disappear in shift!)
|
|
145
|
-
seq_len = seq_len + 1
|
|
143
|
+
seq_len = seq_len + 1 # type: ignore[assignment]
|
|
146
144
|
# Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
|
|
147
145
|
# The "masked" first gt char is <sos>. Delete last logit of the model output.
|
|
148
146
|
cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
|
|
@@ -153,24 +151,31 @@ class MASTER(_MASTER, nn.Module):
|
|
|
153
151
|
ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
|
|
154
152
|
return ce_loss.mean()
|
|
155
153
|
|
|
154
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
155
|
+
"""Load pretrained parameters onto the model
|
|
156
|
+
|
|
157
|
+
Args:
|
|
158
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
159
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
160
|
+
"""
|
|
161
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
162
|
+
|
|
156
163
|
def forward(
|
|
157
164
|
self,
|
|
158
165
|
x: torch.Tensor,
|
|
159
|
-
target:
|
|
166
|
+
target: list[str] | None = None,
|
|
160
167
|
return_model_output: bool = False,
|
|
161
168
|
return_preds: bool = False,
|
|
162
|
-
) ->
|
|
169
|
+
) -> dict[str, Any]:
|
|
163
170
|
"""Call function for training
|
|
164
171
|
|
|
165
172
|
Args:
|
|
166
|
-
----
|
|
167
173
|
x: images
|
|
168
174
|
target: list of str labels
|
|
169
175
|
return_model_output: if True, return logits
|
|
170
176
|
return_preds: if True, decode logits
|
|
171
177
|
|
|
172
178
|
Returns:
|
|
173
|
-
-------
|
|
174
179
|
A dictionnary containing eventually loss, logits and predictions.
|
|
175
180
|
"""
|
|
176
181
|
# Encode
|
|
@@ -181,7 +186,7 @@ class MASTER(_MASTER, nn.Module):
|
|
|
181
186
|
# add positional encoding to features
|
|
182
187
|
encoded = self.positional_encoding(features)
|
|
183
188
|
|
|
184
|
-
out:
|
|
189
|
+
out: dict[str, Any] = {}
|
|
185
190
|
|
|
186
191
|
if self.training and target is None:
|
|
187
192
|
raise ValueError("Need to provide labels during training")
|
|
@@ -213,7 +218,13 @@ class MASTER(_MASTER, nn.Module):
|
|
|
213
218
|
out["out_map"] = logits
|
|
214
219
|
|
|
215
220
|
if return_preds:
|
|
216
|
-
|
|
221
|
+
# Disable for torch.compile compatibility
|
|
222
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
223
|
+
def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
|
|
224
|
+
return self.postprocessor(logits)
|
|
225
|
+
|
|
226
|
+
# Post-process boxes
|
|
227
|
+
out["preds"] = _postprocess(logits)
|
|
217
228
|
|
|
218
229
|
return out
|
|
219
230
|
|
|
@@ -221,12 +232,10 @@ class MASTER(_MASTER, nn.Module):
|
|
|
221
232
|
"""Decode function for prediction
|
|
222
233
|
|
|
223
234
|
Args:
|
|
224
|
-
----
|
|
225
235
|
encoded: input tensor
|
|
226
236
|
|
|
227
237
|
Returns:
|
|
228
|
-
|
|
229
|
-
A Tuple of torch.Tensor: predictions, logits
|
|
238
|
+
A tuple of torch.Tensor: predictions, logits
|
|
230
239
|
"""
|
|
231
240
|
b = encoded.size(0)
|
|
232
241
|
|
|
@@ -254,7 +263,7 @@ class MASTERPostProcessor(_MASTERPostProcessor):
|
|
|
254
263
|
def __call__(
|
|
255
264
|
self,
|
|
256
265
|
logits: torch.Tensor,
|
|
257
|
-
) ->
|
|
266
|
+
) -> list[tuple[str, float]]:
|
|
258
267
|
# compute pred with argmax for attention models
|
|
259
268
|
out_idxs = logits.argmax(-1)
|
|
260
269
|
# N x L
|
|
@@ -277,7 +286,7 @@ def _master(
|
|
|
277
286
|
backbone_fn: Callable[[bool], nn.Module],
|
|
278
287
|
layer: str,
|
|
279
288
|
pretrained_backbone: bool = True,
|
|
280
|
-
ignore_keys:
|
|
289
|
+
ignore_keys: list[str] | None = None,
|
|
281
290
|
**kwargs: Any,
|
|
282
291
|
) -> MASTER:
|
|
283
292
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -301,7 +310,7 @@ def _master(
|
|
|
301
310
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
302
311
|
# remove the last layer weights
|
|
303
312
|
_ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
|
|
304
|
-
|
|
313
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
305
314
|
|
|
306
315
|
return model
|
|
307
316
|
|
|
@@ -316,12 +325,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
|
|
|
316
325
|
>>> out = model(input_tensor)
|
|
317
326
|
|
|
318
327
|
Args:
|
|
319
|
-
----
|
|
320
328
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
321
329
|
**kwargs: keywoard arguments passed to the MASTER architecture
|
|
322
330
|
|
|
323
331
|
Returns:
|
|
324
|
-
-------
|
|
325
332
|
text recognition architecture
|
|
326
333
|
"""
|
|
327
334
|
return _master(
|