python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, layers
@@ -13,19 +13,19 @@ from doctr.datasets import VOCABS
 from doctr.models.classification import magc_resnet31
 from doctr.models.modules.transformer import Decoder, PositionalEncoding
 
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _MASTER, _MASTERPostProcessor
 
 __all__ = ["MASTER", "master"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "master": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0",
     },
 }
 
@@ -35,7 +35,6 @@ class MASTER(_MASTER, Model):
     Implementation based on the official TF implementation: <https://github.com/jiangxiluning/MASTER-TF>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary, (without EOS, SOS, PAD)
         d_model: d parameter for the transformer decoder
@@ -51,7 +50,7 @@ class MASTER(_MASTER, Model):
 
     def __init__(
         self,
-        feature_extractor:
+        feature_extractor: Model,
         vocab: str,
         d_model: int = 512,
         dff: int = 2048,
@@ -59,9 +58,9 @@ class MASTER(_MASTER, Model):
         num_layers: int = 3,
         max_length: int = 50,
         dropout: float = 0.2,
-        input_shape:
+        input_shape: tuple[int, int, int] = (32, 128, 3),  # different from the paper
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
 
@@ -89,7 +88,7 @@ class MASTER(_MASTER, Model):
         self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
 
     @tf.function
-    def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) ->
+    def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
         # [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
         # (N, 1, 1, max_length)
         target_pad_mask = tf.cast(tf.math.not_equal(target, self.vocab_size + 2), dtype=tf.uint8)
@@ -109,19 +108,17 @@ class MASTER(_MASTER, Model):
     def compute_loss(
         model_output: tf.Tensor,
         gt: tf.Tensor,
-        seq_len:
+        seq_len: list[int],
     ) -> tf.Tensor:
         """Compute categorical cross-entropy loss for the model.
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             gt: the encoded tensor with gt labels
             model_output: predicted logits of the model
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
@@ -144,15 +141,14 @@ class MASTER(_MASTER, Model):
     def call(
         self,
         x: tf.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         """Call function for training
 
         Args:
-        ----
             x: images
             target: list of str labels
             return_model_output: if True, return logits
@@ -160,7 +156,6 @@ class MASTER(_MASTER, Model):
             **kwargs: keyword arguments passed to the decoder
 
         Returns:
-        -------
             A dictionnary containing eventually loss, logits and predictions.
         """
         # Encode
@@ -171,7 +166,7 @@ class MASTER(_MASTER, Model):
         # add positional encoding to features
         encoded = self.positional_encoding(feature, **kwargs)
 
-        out:
+        out: dict[str, tf.Tensor] = {}
 
         if kwargs.get("training", False) and target is None:
             raise ValueError("Need to provide labels during training")
@@ -209,13 +204,11 @@ class MASTER(_MASTER, Model):
         """Decode function for prediction
 
         Args:
-        ----
             encoded: encoded features
             **kwargs: keyword arguments passed to the decoder
 
         Returns:
-
-            A Tuple of tf.Tensor: predictions, logits
+            A tuple of tf.Tensor: predictions, logits
         """
         b = encoded.shape[0]
 
@@ -247,14 +240,13 @@ class MASTERPostProcessor(_MASTERPostProcessor):
     """Post processor for MASTER architectures
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         # N x L
@@ -290,9 +282,14 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
 
@@ -307,12 +304,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keywoard arguments passed to the MASTER architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _master("master", pretrained, magc_resnet31, **kwargs)
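The two functional changes in the hunks above are the explicit `_build_model(model)` call before any weights are restored, and the `skip_mismatch` flag passed to `load_pretrained_params`, which skips the vocab-dependent layers whenever a non-default `vocab` is requested instead of failing. A minimal sketch of the fine-tuning workflow this enables; the import path is assumed from doctr's public API, the snippet needs a python-doctr 0.10+ install with a working backend, and it downloads the pretrained weights on first use:

```python
from doctr.datasets import VOCABS
from doctr.models import master

# Pretrained MASTER weights were published for VOCABS["french"]; asking for a
# different vocab now loads every layer except the vocab-dependent ones
# (the skip_mismatch branch in the diff above) instead of erroring out.
model = master(pretrained=True, vocab=VOCABS["english"])
print(type(model).__name__)  # "MASTER", ready to be fine-tuned on the new vocab
```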
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
     from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
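This hunk reverses the backend priority of the conditional re-export: with both frameworks installed, the PyTorch implementation is now imported first and TensorFlow becomes the fallback. A small sketch of the same resolution logic, using only the two helpers that appear in the hunk (`pick_backend` is an illustrative helper, not a doctr function):

```python
from doctr.file_utils import is_tf_available, is_torch_available

def pick_backend() -> str:
    # Mirrors the new import guard: PyTorch wins whenever it is importable.
    if is_torch_available():
        return "pytorch"
    if is_tf_available():
        return "tensorflow"
    raise RuntimeError("Neither PyTorch nor TensorFlow is installed")

print(pick_backend())
```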
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Tuple
 
 import numpy as np
 
@@ -17,17 +16,15 @@ class _PARSeq:
 
     def build_target(
         self,
-        gts:
-    ) ->
+        gts: list[str],
+    ) -> tuple[np.ndarray, list[int]]:
         """Encode a list of gts sequences into a np array and gives the corresponding*
         sequence lengths.
 
         Args:
-        ----
             gts: list of ground-truth labels
 
         Returns:
-        -------
             A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
         """
         encoded = encode_sequences(
@@ -46,7 +43,6 @@ class _PARSeqPostProcessor(RecognitionPostProcessor):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
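Dropping `from typing import List, Tuple` here is part of a pattern that runs through the whole release: typing aliases are replaced by built-in generics and PEP 604 unions, which is why only `Any` remains imported from `typing` and the new signatures read `list[str] | None`. A small, doctr-independent illustration of the two spellings (the `__future__` import keeps it runnable on Python 3.9 as well):

```python
from __future__ import annotations

from typing import Dict, List, Optional, Tuple  # pre-0.11 style imports

import numpy as np

# Old spelling, as in the removed signatures:
def build_target_old(gts: List[str]) -> Tuple[np.ndarray, List[int]]: ...
def forward_old(target: Optional[List[str]] = None) -> Dict[str, object]: ...

# New spelling, as in the added signatures:
def build_target_new(gts: list[str]) -> tuple[np.ndarray, list[int]]: ...
def forward_new(target: list[str] | None = None) -> dict[str, object]: ...
```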
@@ -1,12 +1,13 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
+from collections.abc import Callable
 from copy import deepcopy
 from itertools import permutations
-from typing import Any
+from typing import Any
 
 import numpy as np
 import torch
@@ -23,7 +24,7 @@ from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "parseq": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -38,7 +39,6 @@ class CharEmbedding(nn.Module):
     """Implements the character embedding module
 
     Args:
-    ----
         vocab_size: size of the vocabulary
         d_model: dimension of the model
     """
@@ -56,7 +56,6 @@ class PARSeqDecoder(nn.Module):
     """Implements decoder module of the PARSeq model
 
     Args:
-    ----
         d_model: dimension of the model
         num_heads: number of attention heads
         ffd: dimension of the feed forward layer
@@ -92,7 +91,7 @@ class PARSeqDecoder(nn.Module):
         target,
         content,
         memory,
-        target_mask:
+        target_mask: torch.Tensor | None = None,
     ):
         query_norm = self.query_norm(target)
         content_norm = self.content_norm(content)
@@ -112,7 +111,6 @@ class PARSeq(_PARSeq, nn.Module):
     Slightly modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -136,9 +134,9 @@ class PARSeq(_PARSeq, nn.Module):
         dec_num_heads: int = 12,
         dec_ff_dim: int = 384,  # we use it from the original implementation instead of 2048
         dec_ffd_ratio: int = 4,
-        input_shape:
+        input_shape: tuple[int, int, int] = (3, 32, 128),
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -212,12 +210,12 @@ class PARSeq(_PARSeq, nn.Module):
 
         sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
         eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
-        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
+        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()  # type: ignore[list-item]
         if len(combined) > 1:
             combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
         return combined
 
-    def generate_permutations_attention_masks(self, permutation: torch.Tensor) ->
+    def generate_permutations_attention_masks(self, permutation: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         # Generate source and target mask for the decoder attention.
         sz = permutation.shape[0]
         mask = torch.ones((sz, sz), device=permutation.device)
@@ -236,8 +234,8 @@ class PARSeq(_PARSeq, nn.Module):
         self,
         target: torch.Tensor,
         memory: torch.Tensor,
-        target_mask:
-        target_query:
+        target_mask: torch.Tensor | None = None,
+        target_query: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Add positional information to the target sequence and pass it through the decoder."""
         batch_size, sequence_length = target.shape
@@ -250,7 +248,7 @@ class PARSeq(_PARSeq, nn.Module):
         target_query = self.dropout(target_query)
         return self.decoder(target_query, content, memory, target_mask)
 
-    def decode_autoregressive(self, features: torch.Tensor, max_len:
+    def decode_autoregressive(self, features: torch.Tensor, max_len: int | None = None) -> torch.Tensor:
         """Generate predictions for the given features."""
         max_length = max_len if max_len is not None else self.max_length
         max_length = min(max_length, self.max_length) + 1
@@ -283,7 +281,7 @@ class PARSeq(_PARSeq, nn.Module):
 
             # Stop decoding if all sequences have reached the EOS token
             # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
+            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():  # type: ignore[attr-defined]
                 break
 
         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)
@@ -298,7 +296,7 @@ class PARSeq(_PARSeq, nn.Module):
 
             # Create padding mask for refined target input maskes all behind EOS token as False
             # (N, 1, 1, max_length)
-            target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+            target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)  # type: ignore[attr-defined]
             mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
             logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
 
@@ -307,10 +305,10 @@ class PARSeq(_PARSeq, nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]  # (batch_size, patches_seqlen, d_model)
         # remove cls token
         features = features[:, 1:, :]
@@ -337,7 +335,7 @@ class PARSeq(_PARSeq, nn.Module):
             ).unsqueeze(1).unsqueeze(1)  # (N, 1, 1, seq_len)
 
             loss = torch.tensor(0.0, device=features.device)
-            loss_numel:
+            loss_numel: int | float = 0
             n = (gt_out != self.vocab_size + 2).sum().item()
             for i, perm in enumerate(tgt_perms):
                 _, target_mask = self.generate_permutations_attention_masks(perm)  # (seq_len, seq_len)
@@ -365,7 +363,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         logits = _bf16_to_float32(logits)
 
-        out:
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -374,8 +372,13 @@ class PARSeq(_PARSeq, nn.Module):
             out["out_map"] = logits
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(logits)
+
             # Post-process boxes
-            out["preds"] =
+            out["preds"] = _postprocess(logits)
 
         if target is not None:
             out["loss"] = loss
@@ -387,14 +390,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
     """Post processor for PARSeq architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: torch.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
         preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
@@ -417,7 +419,7 @@ def _parseq(
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
-    ignore_keys:
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> PARSeq:
     # Patch the config
@@ -462,12 +464,10 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the PARSeq architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _parseq(
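Besides the typing changes, the notable behavioural addition in this file is that the string post-processing inside `forward` is now wrapped in a `torch.compiler.disable`-decorated closure, so a model run under `torch.compile` graph-breaks cleanly at that point instead of trying to trace plain-Python decoding. A minimal sketch of the same pattern outside doctr, assuming PyTorch 2.1+ (where `torch.compiler.disable` is available):

```python
import torch

class TinyRecognizer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> dict:
        logits = self.linear(x)

        # Keep the Python-level decoding out of the compiled graph, as the
        # PARSeq forward now does with its postprocessor.
        @torch.compiler.disable
        def _postprocess(logits: torch.Tensor) -> list[str]:
            return [f"class_{int(i)}" for i in logits.argmax(-1)]

        return {"out_map": logits, "preds": _postprocess(logits)}

model = torch.compile(TinyRecognizer())
print(model(torch.randn(2, 8))["preds"])
```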
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 import math
 from copy import deepcopy
 from itertools import permutations
-from typing import Any
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -16,18 +16,18 @@ from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
 
 from ...classification import vit_s
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "parseq": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
     },
 }
 
@@ -36,14 +36,14 @@ class CharEmbedding(layers.Layer):
     """Implements the character embedding module
 
     Args:
-
+    -
         vocab_size: size of the vocabulary
         d_model: dimension of the model
     """
 
     def __init__(self, vocab_size: int, d_model: int):
         super(CharEmbedding, self).__init__()
-        self.embedding =
+        self.embedding = layers.Embedding(vocab_size, d_model)
         self.d_model = d_model
 
     def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
@@ -54,7 +54,6 @@ class PARSeqDecoder(layers.Layer):
     """Implements decoder module of the PARSeq model
 
     Args:
-    ----
         d_model: dimension of the model
         num_heads: number of attention heads
         ffd: dimension of the feed forward layer
@@ -115,7 +114,6 @@ class PARSeq(_PARSeq, Model):
     Modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -129,7 +127,7 @@ class PARSeq(_PARSeq, Model):
         cfg: dictionary containing information about the model
     """
 
-    _children_names:
+    _children_names: list[str] = ["feat_extractor", "postprocessor"]
 
     def __init__(
         self,
@@ -141,9 +139,9 @@ class PARSeq(_PARSeq, Model):
         dec_num_heads: int = 12,
         dec_ff_dim: int = 384,  # we use it from the original implementation instead of 2048
         dec_ffd_ratio: int = 4,
-        input_shape:
+        input_shape: tuple[int, int, int] = (32, 128, 3),
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -167,7 +165,6 @@ class PARSeq(_PARSeq, Model):
 
         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
 
-    @tf.function
     def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
         # Generates permutations of the target sequence.
         # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -214,8 +211,7 @@ class PARSeq(_PARSeq, Model):
         )
         return combined
 
-
-    def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
         # Generate source and target mask for the decoder attention.
         sz = permutation.shape[0]
         mask = tf.ones((sz, sz), dtype=tf.float32)
@@ -234,13 +230,12 @@ class PARSeq(_PARSeq, Model):
         target_mask = mask[1:, :-1]
         return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)
 
-    @tf.function
     def decode(
         self,
         target: tf.Tensor,
-        memory: tf,
-        target_mask:
-        target_query:
+        memory: tf.Tensor,
+        target_mask: tf.Tensor | None = None,
+        target_query: tf.Tensor | None = None,
         **kwargs: Any,
     ) -> tf.Tensor:
         batch_size, sequence_length = target.shape
@@ -253,8 +248,7 @@ class PARSeq(_PARSeq, Model):
         target_query = self.dropout(target_query, **kwargs)
         return self.decoder(target_query, content, memory, target_mask, **kwargs)
 
-
-    def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
+    def decode_autoregressive(self, features: tf.Tensor, max_len: int | None = None, **kwargs) -> tf.Tensor:
         """Generate predictions for the given features."""
         max_length = max_len if max_len is not None else self.max_length
         max_length = min(max_length, self.max_length) + 1
@@ -321,11 +315,11 @@ class PARSeq(_PARSeq, Model):
     def call(
         self,
         x: tf.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
         # remove cls token
         features = features[:, 1:, :]
@@ -396,7 +390,7 @@ class PARSeq(_PARSeq, Model):
 
         logits = _bf16_to_float32(logits)
 
-        out:
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -418,14 +412,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
     """Post processor for PARSeq architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
@@ -451,7 +444,7 @@ def _parseq(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    input_shape:
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> PARSeq:
     # Patch the config
@@ -476,9 +469,14 @@ def _parseq(
 
     # Build the model
     model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
 
@@ -494,12 +492,10 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
     >>> out = model(input_tensor)
 
    Args:
-    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the PARSeq architecture
 
    Returns:
-    -------
        text recognition architecture
    """
    return _parseq(
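On the TensorFlow side the pretrained URLs now point at Keras-3-style `*.weights.h5` files, and `_build_model(model)` runs before `load_pretrained_params`. The underlying constraint is a Keras one rather than a doctr one: a subclassed model owns no variables until it has been called, so weights can only be restored after a build step. A minimal, doctr-independent sketch of that requirement (what `_build_model` does internally is an assumption here):

```python
import numpy as np
import tensorflow as tf

class ToyModel(tf.keras.Model):
    def __init__(self) -> None:
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        return self.dense(x)

model = ToyModel()
model(np.zeros((1, 8), dtype=np.float32))  # forward pass creates the variables ("builds" the model)
model.save_weights("toy.weights.h5")  # the ".weights.h5" suffix matches the new pretrained URLs

restored = ToyModel()
restored(np.zeros((1, 8), dtype=np.float32))  # must be built before weights can be loaded
restored.load_weights("toy.weights.h5")
```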
@@ -1,6 +1,6 @@
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 
-if
-    from .
-
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]