python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +8 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +7 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +4 -5
- doctr/datasets/ic13.py +4 -5
- doctr/datasets/iiit5k.py +6 -5
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +6 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +6 -5
- doctr/datasets/svt.py +4 -5
- doctr/datasets/synthtext.py +4 -5
- doctr/datasets/utils.py +34 -29
- doctr/datasets/vocabs.py +17 -7
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +2 -7
- doctr/io/elements.py +59 -79
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +30 -48
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +8 -11
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +5 -17
- doctr/models/classification/mobilenet/tensorflow.py +8 -21
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +6 -8
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +20 -31
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +8 -15
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +9 -12
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +6 -12
- doctr/models/classification/zoo.py +19 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +14 -26
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +14 -23
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +3 -7
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +18 -19
- doctr/models/kie_predictor/tensorflow.py +13 -14
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +12 -13
- doctr/models/predictor/tensorflow.py +8 -9
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +11 -23
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +12 -22
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +16 -22
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +12 -21
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +12 -20
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +14 -17
- doctr/models/utils/tensorflow.py +17 -16
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +16 -47
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from copy import deepcopy
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import tensorflow as tf
|
|
10
10
|
from tensorflow.keras import Model, layers
|
|
@@ -19,7 +19,7 @@ from .base import _MASTER, _MASTERPostProcessor
|
|
|
19
19
|
__all__ = ["MASTER", "master"]
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
default_cfgs:
|
|
22
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
23
23
|
"master": {
|
|
24
24
|
"mean": (0.694, 0.695, 0.693),
|
|
25
25
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -35,7 +35,6 @@ class MASTER(_MASTER, Model):
|
|
|
35
35
|
Implementation based on the official TF implementation: <https://github.com/jiangxiluning/MASTER-TF>`_.
|
|
36
36
|
|
|
37
37
|
Args:
|
|
38
|
-
----
|
|
39
38
|
feature_extractor: the backbone serving as feature extractor
|
|
40
39
|
vocab: vocabulary, (without EOS, SOS, PAD)
|
|
41
40
|
d_model: d parameter for the transformer decoder
|
|
@@ -59,9 +58,9 @@ class MASTER(_MASTER, Model):
|
|
|
59
58
|
num_layers: int = 3,
|
|
60
59
|
max_length: int = 50,
|
|
61
60
|
dropout: float = 0.2,
|
|
62
|
-
input_shape:
|
|
61
|
+
input_shape: tuple[int, int, int] = (32, 128, 3), # different from the paper
|
|
63
62
|
exportable: bool = False,
|
|
64
|
-
cfg:
|
|
63
|
+
cfg: dict[str, Any] | None = None,
|
|
65
64
|
) -> None:
|
|
66
65
|
super().__init__()
|
|
67
66
|
|
|
@@ -89,7 +88,7 @@ class MASTER(_MASTER, Model):
|
|
|
89
88
|
self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
|
|
90
89
|
|
|
91
90
|
@tf.function
|
|
92
|
-
def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) ->
|
|
91
|
+
def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
|
|
93
92
|
# [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
|
|
94
93
|
# (N, 1, 1, max_length)
|
|
95
94
|
target_pad_mask = tf.cast(tf.math.not_equal(target, self.vocab_size + 2), dtype=tf.uint8)
|
|
@@ -109,19 +108,17 @@ class MASTER(_MASTER, Model):
|
|
|
109
108
|
def compute_loss(
|
|
110
109
|
model_output: tf.Tensor,
|
|
111
110
|
gt: tf.Tensor,
|
|
112
|
-
seq_len:
|
|
111
|
+
seq_len: list[int],
|
|
113
112
|
) -> tf.Tensor:
|
|
114
113
|
"""Compute categorical cross-entropy loss for the model.
|
|
115
114
|
Sequences are masked after the EOS character.
|
|
116
115
|
|
|
117
116
|
Args:
|
|
118
|
-
----
|
|
119
117
|
gt: the encoded tensor with gt labels
|
|
120
118
|
model_output: predicted logits of the model
|
|
121
119
|
seq_len: lengths of each gt word inside the batch
|
|
122
120
|
|
|
123
121
|
Returns:
|
|
124
|
-
-------
|
|
125
122
|
The loss of the model on the batch
|
|
126
123
|
"""
|
|
127
124
|
# Input length : number of timesteps
|
|
@@ -144,15 +141,14 @@ class MASTER(_MASTER, Model):
|
|
|
144
141
|
def call(
|
|
145
142
|
self,
|
|
146
143
|
x: tf.Tensor,
|
|
147
|
-
target:
|
|
144
|
+
target: list[str] | None = None,
|
|
148
145
|
return_model_output: bool = False,
|
|
149
146
|
return_preds: bool = False,
|
|
150
147
|
**kwargs: Any,
|
|
151
|
-
) ->
|
|
148
|
+
) -> dict[str, Any]:
|
|
152
149
|
"""Call function for training
|
|
153
150
|
|
|
154
151
|
Args:
|
|
155
|
-
----
|
|
156
152
|
x: images
|
|
157
153
|
target: list of str labels
|
|
158
154
|
return_model_output: if True, return logits
|
|
@@ -160,7 +156,6 @@ class MASTER(_MASTER, Model):
|
|
|
160
156
|
**kwargs: keyword arguments passed to the decoder
|
|
161
157
|
|
|
162
158
|
Returns:
|
|
163
|
-
-------
|
|
164
159
|
A dictionnary containing eventually loss, logits and predictions.
|
|
165
160
|
"""
|
|
166
161
|
# Encode
|
|
@@ -171,7 +166,7 @@ class MASTER(_MASTER, Model):
|
|
|
171
166
|
# add positional encoding to features
|
|
172
167
|
encoded = self.positional_encoding(feature, **kwargs)
|
|
173
168
|
|
|
174
|
-
out:
|
|
169
|
+
out: dict[str, tf.Tensor] = {}
|
|
175
170
|
|
|
176
171
|
if kwargs.get("training", False) and target is None:
|
|
177
172
|
raise ValueError("Need to provide labels during training")
|
|
@@ -209,13 +204,11 @@ class MASTER(_MASTER, Model):
|
|
|
209
204
|
"""Decode function for prediction
|
|
210
205
|
|
|
211
206
|
Args:
|
|
212
|
-
----
|
|
213
207
|
encoded: encoded features
|
|
214
208
|
**kwargs: keyword arguments passed to the decoder
|
|
215
209
|
|
|
216
210
|
Returns:
|
|
217
|
-
|
|
218
|
-
A Tuple of tf.Tensor: predictions, logits
|
|
211
|
+
A tuple of tf.Tensor: predictions, logits
|
|
219
212
|
"""
|
|
220
213
|
b = encoded.shape[0]
|
|
221
214
|
|
|
@@ -247,14 +240,13 @@ class MASTERPostProcessor(_MASTERPostProcessor):
|
|
|
247
240
|
"""Post processor for MASTER architectures
|
|
248
241
|
|
|
249
242
|
Args:
|
|
250
|
-
----
|
|
251
243
|
vocab: string containing the ordered sequence of supported characters
|
|
252
244
|
"""
|
|
253
245
|
|
|
254
246
|
def __call__(
|
|
255
247
|
self,
|
|
256
248
|
logits: tf.Tensor,
|
|
257
|
-
) ->
|
|
249
|
+
) -> list[tuple[str, float]]:
|
|
258
250
|
# compute pred with argmax for attention models
|
|
259
251
|
out_idxs = tf.math.argmax(logits, axis=2)
|
|
260
252
|
# N x L
|
|
@@ -312,12 +304,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
|
|
|
312
304
|
>>> out = model(input_tensor)
|
|
313
305
|
|
|
314
306
|
Args:
|
|
315
|
-
----
|
|
316
307
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
317
308
|
**kwargs: keywoard arguments passed to the MASTER architecture
|
|
318
309
|
|
|
319
310
|
Returns:
|
|
320
|
-
-------
|
|
321
311
|
text recognition architecture
|
|
322
312
|
"""
|
|
323
313
|
return _master("master", pretrained, magc_resnet31, **kwargs)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List, Tuple
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
|
|
@@ -17,17 +16,15 @@ class _PARSeq:
|
|
|
17
16
|
|
|
18
17
|
def build_target(
|
|
19
18
|
self,
|
|
20
|
-
gts:
|
|
21
|
-
) ->
|
|
19
|
+
gts: list[str],
|
|
20
|
+
) -> tuple[np.ndarray, list[int]]:
|
|
22
21
|
"""Encode a list of gts sequences into a np array and gives the corresponding*
|
|
23
22
|
sequence lengths.
|
|
24
23
|
|
|
25
24
|
Args:
|
|
26
|
-
----
|
|
27
25
|
gts: list of ground-truth labels
|
|
28
26
|
|
|
29
27
|
Returns:
|
|
30
|
-
-------
|
|
31
28
|
A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
|
|
32
29
|
"""
|
|
33
30
|
encoded = encode_sequences(
|
|
@@ -46,7 +43,6 @@ class _PARSeqPostProcessor(RecognitionPostProcessor):
|
|
|
46
43
|
"""Abstract class to postprocess the raw output of the model
|
|
47
44
|
|
|
48
45
|
Args:
|
|
49
|
-
----
|
|
50
46
|
vocab: string containing the ordered sequence of supported characters
|
|
51
47
|
"""
|
|
52
48
|
|
|
@@ -1,12 +1,13 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
+
from collections.abc import Callable
|
|
7
8
|
from copy import deepcopy
|
|
8
9
|
from itertools import permutations
|
|
9
|
-
from typing import Any
|
|
10
|
+
from typing import Any
|
|
10
11
|
|
|
11
12
|
import numpy as np
|
|
12
13
|
import torch
|
|
@@ -23,7 +24,7 @@ from .base import _PARSeq, _PARSeqPostProcessor
|
|
|
23
24
|
|
|
24
25
|
__all__ = ["PARSeq", "parseq"]
|
|
25
26
|
|
|
26
|
-
default_cfgs:
|
|
27
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
27
28
|
"parseq": {
|
|
28
29
|
"mean": (0.694, 0.695, 0.693),
|
|
29
30
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -38,7 +39,6 @@ class CharEmbedding(nn.Module):
|
|
|
38
39
|
"""Implements the character embedding module
|
|
39
40
|
|
|
40
41
|
Args:
|
|
41
|
-
----
|
|
42
42
|
vocab_size: size of the vocabulary
|
|
43
43
|
d_model: dimension of the model
|
|
44
44
|
"""
|
|
@@ -56,7 +56,6 @@ class PARSeqDecoder(nn.Module):
|
|
|
56
56
|
"""Implements decoder module of the PARSeq model
|
|
57
57
|
|
|
58
58
|
Args:
|
|
59
|
-
----
|
|
60
59
|
d_model: dimension of the model
|
|
61
60
|
num_heads: number of attention heads
|
|
62
61
|
ffd: dimension of the feed forward layer
|
|
@@ -92,7 +91,7 @@ class PARSeqDecoder(nn.Module):
|
|
|
92
91
|
target,
|
|
93
92
|
content,
|
|
94
93
|
memory,
|
|
95
|
-
target_mask:
|
|
94
|
+
target_mask: torch.Tensor | None = None,
|
|
96
95
|
):
|
|
97
96
|
query_norm = self.query_norm(target)
|
|
98
97
|
content_norm = self.content_norm(content)
|
|
@@ -112,7 +111,6 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
112
111
|
Slightly modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
|
|
113
112
|
|
|
114
113
|
Args:
|
|
115
|
-
----
|
|
116
114
|
feature_extractor: the backbone serving as feature extractor
|
|
117
115
|
vocab: vocabulary used for encoding
|
|
118
116
|
embedding_units: number of embedding units
|
|
@@ -136,9 +134,9 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
136
134
|
dec_num_heads: int = 12,
|
|
137
135
|
dec_ff_dim: int = 384, # we use it from the original implementation instead of 2048
|
|
138
136
|
dec_ffd_ratio: int = 4,
|
|
139
|
-
input_shape:
|
|
137
|
+
input_shape: tuple[int, int, int] = (3, 32, 128),
|
|
140
138
|
exportable: bool = False,
|
|
141
|
-
cfg:
|
|
139
|
+
cfg: dict[str, Any] | None = None,
|
|
142
140
|
) -> None:
|
|
143
141
|
super().__init__()
|
|
144
142
|
self.vocab = vocab
|
|
@@ -212,12 +210,12 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
212
210
|
|
|
213
211
|
sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
|
|
214
212
|
eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
|
|
215
|
-
combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
|
|
213
|
+
combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int() # type: ignore[list-item]
|
|
216
214
|
if len(combined) > 1:
|
|
217
215
|
combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
|
|
218
216
|
return combined
|
|
219
217
|
|
|
220
|
-
def generate_permutations_attention_masks(self, permutation: torch.Tensor) ->
|
|
218
|
+
def generate_permutations_attention_masks(self, permutation: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
221
219
|
# Generate source and target mask for the decoder attention.
|
|
222
220
|
sz = permutation.shape[0]
|
|
223
221
|
mask = torch.ones((sz, sz), device=permutation.device)
|
|
@@ -236,8 +234,8 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
236
234
|
self,
|
|
237
235
|
target: torch.Tensor,
|
|
238
236
|
memory: torch.Tensor,
|
|
239
|
-
target_mask:
|
|
240
|
-
target_query:
|
|
237
|
+
target_mask: torch.Tensor | None = None,
|
|
238
|
+
target_query: torch.Tensor | None = None,
|
|
241
239
|
) -> torch.Tensor:
|
|
242
240
|
"""Add positional information to the target sequence and pass it through the decoder."""
|
|
243
241
|
batch_size, sequence_length = target.shape
|
|
@@ -250,7 +248,7 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
250
248
|
target_query = self.dropout(target_query)
|
|
251
249
|
return self.decoder(target_query, content, memory, target_mask)
|
|
252
250
|
|
|
253
|
-
def decode_autoregressive(self, features: torch.Tensor, max_len:
|
|
251
|
+
def decode_autoregressive(self, features: torch.Tensor, max_len: int | None = None) -> torch.Tensor:
|
|
254
252
|
"""Generate predictions for the given features."""
|
|
255
253
|
max_length = max_len if max_len is not None else self.max_length
|
|
256
254
|
max_length = min(max_length, self.max_length) + 1
|
|
@@ -283,7 +281,7 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
283
281
|
|
|
284
282
|
# Stop decoding if all sequences have reached the EOS token
|
|
285
283
|
# NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
|
|
286
|
-
if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
|
|
284
|
+
if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all(): # type: ignore[attr-defined]
|
|
287
285
|
break
|
|
288
286
|
|
|
289
287
|
logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1)
|
|
@@ -298,7 +296,7 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
298
296
|
|
|
299
297
|
# Create padding mask for refined target input maskes all behind EOS token as False
|
|
300
298
|
# (N, 1, 1, max_length)
|
|
301
|
-
target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
|
|
299
|
+
target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
|
|
302
300
|
mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
|
|
303
301
|
logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
|
|
304
302
|
|
|
@@ -307,10 +305,10 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
307
305
|
def forward(
|
|
308
306
|
self,
|
|
309
307
|
x: torch.Tensor,
|
|
310
|
-
target:
|
|
308
|
+
target: list[str] | None = None,
|
|
311
309
|
return_model_output: bool = False,
|
|
312
310
|
return_preds: bool = False,
|
|
313
|
-
) ->
|
|
311
|
+
) -> dict[str, Any]:
|
|
314
312
|
features = self.feat_extractor(x)["features"] # (batch_size, patches_seqlen, d_model)
|
|
315
313
|
# remove cls token
|
|
316
314
|
features = features[:, 1:, :]
|
|
@@ -337,7 +335,7 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
337
335
|
).unsqueeze(1).unsqueeze(1) # (N, 1, 1, seq_len)
|
|
338
336
|
|
|
339
337
|
loss = torch.tensor(0.0, device=features.device)
|
|
340
|
-
loss_numel:
|
|
338
|
+
loss_numel: int | float = 0
|
|
341
339
|
n = (gt_out != self.vocab_size + 2).sum().item()
|
|
342
340
|
for i, perm in enumerate(tgt_perms):
|
|
343
341
|
_, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len)
|
|
@@ -365,7 +363,7 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
365
363
|
|
|
366
364
|
logits = _bf16_to_float32(logits)
|
|
367
365
|
|
|
368
|
-
out:
|
|
366
|
+
out: dict[str, Any] = {}
|
|
369
367
|
if self.exportable:
|
|
370
368
|
out["logits"] = logits
|
|
371
369
|
return out
|
|
@@ -374,8 +372,13 @@ class PARSeq(_PARSeq, nn.Module):
|
|
|
374
372
|
out["out_map"] = logits
|
|
375
373
|
|
|
376
374
|
if target is None or return_preds:
|
|
375
|
+
# Disable for torch.compile compatibility
|
|
376
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
377
|
+
def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
|
|
378
|
+
return self.postprocessor(logits)
|
|
379
|
+
|
|
377
380
|
# Post-process boxes
|
|
378
|
-
out["preds"] =
|
|
381
|
+
out["preds"] = _postprocess(logits)
|
|
379
382
|
|
|
380
383
|
if target is not None:
|
|
381
384
|
out["loss"] = loss
|
|
@@ -387,14 +390,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
|
|
|
387
390
|
"""Post processor for PARSeq architecture
|
|
388
391
|
|
|
389
392
|
Args:
|
|
390
|
-
----
|
|
391
393
|
vocab: string containing the ordered sequence of supported characters
|
|
392
394
|
"""
|
|
393
395
|
|
|
394
396
|
def __call__(
|
|
395
397
|
self,
|
|
396
398
|
logits: torch.Tensor,
|
|
397
|
-
) ->
|
|
399
|
+
) -> list[tuple[str, float]]:
|
|
398
400
|
# compute pred with argmax for attention models
|
|
399
401
|
out_idxs = logits.argmax(-1)
|
|
400
402
|
preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
|
|
@@ -417,7 +419,7 @@ def _parseq(
|
|
|
417
419
|
pretrained: bool,
|
|
418
420
|
backbone_fn: Callable[[bool], nn.Module],
|
|
419
421
|
layer: str,
|
|
420
|
-
ignore_keys:
|
|
422
|
+
ignore_keys: list[str] | None = None,
|
|
421
423
|
**kwargs: Any,
|
|
422
424
|
) -> PARSeq:
|
|
423
425
|
# Patch the config
|
|
@@ -462,12 +464,10 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
|
|
|
462
464
|
>>> out = model(input_tensor)
|
|
463
465
|
|
|
464
466
|
Args:
|
|
465
|
-
----
|
|
466
467
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
467
468
|
**kwargs: keyword arguments of the PARSeq architecture
|
|
468
469
|
|
|
469
470
|
Returns:
|
|
470
|
-
-------
|
|
471
471
|
text recognition architecture
|
|
472
472
|
"""
|
|
473
473
|
return _parseq(
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
import math
|
|
7
7
|
from copy import deepcopy
|
|
8
8
|
from itertools import permutations
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import tensorflow as tf
|
|
@@ -21,7 +21,7 @@ from .base import _PARSeq, _PARSeqPostProcessor
|
|
|
21
21
|
|
|
22
22
|
__all__ = ["PARSeq", "parseq"]
|
|
23
23
|
|
|
24
|
-
default_cfgs:
|
|
24
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
25
25
|
"parseq": {
|
|
26
26
|
"mean": (0.694, 0.695, 0.693),
|
|
27
27
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -36,7 +36,7 @@ class CharEmbedding(layers.Layer):
|
|
|
36
36
|
"""Implements the character embedding module
|
|
37
37
|
|
|
38
38
|
Args:
|
|
39
|
-
|
|
39
|
+
-
|
|
40
40
|
vocab_size: size of the vocabulary
|
|
41
41
|
d_model: dimension of the model
|
|
42
42
|
"""
|
|
@@ -54,7 +54,6 @@ class PARSeqDecoder(layers.Layer):
|
|
|
54
54
|
"""Implements decoder module of the PARSeq model
|
|
55
55
|
|
|
56
56
|
Args:
|
|
57
|
-
----
|
|
58
57
|
d_model: dimension of the model
|
|
59
58
|
num_heads: number of attention heads
|
|
60
59
|
ffd: dimension of the feed forward layer
|
|
@@ -115,7 +114,6 @@ class PARSeq(_PARSeq, Model):
|
|
|
115
114
|
Modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
|
|
116
115
|
|
|
117
116
|
Args:
|
|
118
|
-
----
|
|
119
117
|
feature_extractor: the backbone serving as feature extractor
|
|
120
118
|
vocab: vocabulary used for encoding
|
|
121
119
|
embedding_units: number of embedding units
|
|
@@ -129,7 +127,7 @@ class PARSeq(_PARSeq, Model):
|
|
|
129
127
|
cfg: dictionary containing information about the model
|
|
130
128
|
"""
|
|
131
129
|
|
|
132
|
-
_children_names:
|
|
130
|
+
_children_names: list[str] = ["feat_extractor", "postprocessor"]
|
|
133
131
|
|
|
134
132
|
def __init__(
|
|
135
133
|
self,
|
|
@@ -141,9 +139,9 @@ class PARSeq(_PARSeq, Model):
|
|
|
141
139
|
dec_num_heads: int = 12,
|
|
142
140
|
dec_ff_dim: int = 384, # we use it from the original implementation instead of 2048
|
|
143
141
|
dec_ffd_ratio: int = 4,
|
|
144
|
-
input_shape:
|
|
142
|
+
input_shape: tuple[int, int, int] = (32, 128, 3),
|
|
145
143
|
exportable: bool = False,
|
|
146
|
-
cfg:
|
|
144
|
+
cfg: dict[str, Any] | None = None,
|
|
147
145
|
) -> None:
|
|
148
146
|
super().__init__()
|
|
149
147
|
self.vocab = vocab
|
|
@@ -213,7 +211,7 @@ class PARSeq(_PARSeq, Model):
|
|
|
213
211
|
)
|
|
214
212
|
return combined
|
|
215
213
|
|
|
216
|
-
def generate_permutations_attention_masks(self, permutation: tf.Tensor) ->
|
|
214
|
+
def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
|
|
217
215
|
# Generate source and target mask for the decoder attention.
|
|
218
216
|
sz = permutation.shape[0]
|
|
219
217
|
mask = tf.ones((sz, sz), dtype=tf.float32)
|
|
@@ -236,8 +234,8 @@ class PARSeq(_PARSeq, Model):
|
|
|
236
234
|
self,
|
|
237
235
|
target: tf.Tensor,
|
|
238
236
|
memory: tf.Tensor,
|
|
239
|
-
target_mask:
|
|
240
|
-
target_query:
|
|
237
|
+
target_mask: tf.Tensor | None = None,
|
|
238
|
+
target_query: tf.Tensor | None = None,
|
|
241
239
|
**kwargs: Any,
|
|
242
240
|
) -> tf.Tensor:
|
|
243
241
|
batch_size, sequence_length = target.shape
|
|
@@ -250,8 +248,7 @@ class PARSeq(_PARSeq, Model):
|
|
|
250
248
|
target_query = self.dropout(target_query, **kwargs)
|
|
251
249
|
return self.decoder(target_query, content, memory, target_mask, **kwargs)
|
|
252
250
|
|
|
253
|
-
|
|
254
|
-
def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
|
|
251
|
+
def decode_autoregressive(self, features: tf.Tensor, max_len: int | None = None, **kwargs) -> tf.Tensor:
|
|
255
252
|
"""Generate predictions for the given features."""
|
|
256
253
|
max_length = max_len if max_len is not None else self.max_length
|
|
257
254
|
max_length = min(max_length, self.max_length) + 1
|
|
@@ -318,11 +315,11 @@ class PARSeq(_PARSeq, Model):
|
|
|
318
315
|
def call(
|
|
319
316
|
self,
|
|
320
317
|
x: tf.Tensor,
|
|
321
|
-
target:
|
|
318
|
+
target: list[str] | None = None,
|
|
322
319
|
return_model_output: bool = False,
|
|
323
320
|
return_preds: bool = False,
|
|
324
321
|
**kwargs: Any,
|
|
325
|
-
) ->
|
|
322
|
+
) -> dict[str, Any]:
|
|
326
323
|
features = self.feat_extractor(x, **kwargs) # (batch_size, patches_seqlen, d_model)
|
|
327
324
|
# remove cls token
|
|
328
325
|
features = features[:, 1:, :]
|
|
@@ -393,7 +390,7 @@ class PARSeq(_PARSeq, Model):
|
|
|
393
390
|
|
|
394
391
|
logits = _bf16_to_float32(logits)
|
|
395
392
|
|
|
396
|
-
out:
|
|
393
|
+
out: dict[str, tf.Tensor] = {}
|
|
397
394
|
if self.exportable:
|
|
398
395
|
out["logits"] = logits
|
|
399
396
|
return out
|
|
@@ -415,14 +412,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
|
|
|
415
412
|
"""Post processor for PARSeq architecture
|
|
416
413
|
|
|
417
414
|
Args:
|
|
418
|
-
----
|
|
419
415
|
vocab: string containing the ordered sequence of supported characters
|
|
420
416
|
"""
|
|
421
417
|
|
|
422
418
|
def __call__(
|
|
423
419
|
self,
|
|
424
420
|
logits: tf.Tensor,
|
|
425
|
-
) ->
|
|
421
|
+
) -> list[tuple[str, float]]:
|
|
426
422
|
# compute pred with argmax for attention models
|
|
427
423
|
out_idxs = tf.math.argmax(logits, axis=2)
|
|
428
424
|
preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
|
|
@@ -448,7 +444,7 @@ def _parseq(
|
|
|
448
444
|
arch: str,
|
|
449
445
|
pretrained: bool,
|
|
450
446
|
backbone_fn,
|
|
451
|
-
input_shape:
|
|
447
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
452
448
|
**kwargs: Any,
|
|
453
449
|
) -> PARSeq:
|
|
454
450
|
# Patch the config
|
|
@@ -496,12 +492,10 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
|
|
|
496
492
|
>>> out = model(input_tensor)
|
|
497
493
|
|
|
498
494
|
Args:
|
|
499
|
-
----
|
|
500
495
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
501
496
|
**kwargs: keyword arguments of the PARSeq architecture
|
|
502
497
|
|
|
503
498
|
Returns:
|
|
504
|
-
-------
|
|
505
499
|
text recognition architecture
|
|
506
500
|
"""
|
|
507
501
|
return _parseq(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
from doctr.file_utils import is_tf_available
|
|
1
|
+
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List, Tuple, Union
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
|
|
@@ -13,16 +12,15 @@ __all__ = ["split_crops", "remap_preds"]
|
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
def split_crops(
|
|
16
|
-
crops:
|
|
15
|
+
crops: list[np.ndarray],
|
|
17
16
|
max_ratio: float,
|
|
18
17
|
target_ratio: int,
|
|
19
18
|
dilation: float,
|
|
20
19
|
channels_last: bool = True,
|
|
21
|
-
) ->
|
|
20
|
+
) -> tuple[list[np.ndarray], list[int | tuple[int, int]], bool]:
|
|
22
21
|
"""Chunk crops horizontally to match a given aspect ratio
|
|
23
22
|
|
|
24
23
|
Args:
|
|
25
|
-
----
|
|
26
24
|
crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
|
|
27
25
|
max_ratio: the maximum aspect ratio that won't trigger the chunk
|
|
28
26
|
target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
|
|
@@ -30,12 +28,11 @@ def split_crops(
|
|
|
30
28
|
channels_last: whether the numpy array has dimensions in channels last order
|
|
31
29
|
|
|
32
30
|
Returns:
|
|
33
|
-
-------
|
|
34
31
|
a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
|
|
35
32
|
"""
|
|
36
33
|
_remap_required = False
|
|
37
|
-
crop_map:
|
|
38
|
-
new_crops:
|
|
34
|
+
crop_map: list[int | tuple[int, int]] = []
|
|
35
|
+
new_crops: list[np.ndarray] = []
|
|
39
36
|
for crop in crops:
|
|
40
37
|
h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
|
|
41
38
|
aspect_ratio = w / h
|
|
@@ -71,8 +68,8 @@ def split_crops(
|
|
|
71
68
|
|
|
72
69
|
|
|
73
70
|
def remap_preds(
|
|
74
|
-
preds:
|
|
75
|
-
) ->
|
|
71
|
+
preds: list[tuple[str, float]], crop_map: list[int | tuple[int, int]], dilation: float
|
|
72
|
+
) -> list[tuple[str, float]]:
|
|
76
73
|
remapped_out = []
|
|
77
74
|
for _idx in crop_map:
|
|
78
75
|
# Crop hasn't been split
|
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
import torch
|
|
@@ -21,7 +22,6 @@ class RecognitionPredictor(nn.Module):
|
|
|
21
22
|
"""Implements an object able to identify character sequences in images
|
|
22
23
|
|
|
23
24
|
Args:
|
|
24
|
-
----
|
|
25
25
|
pre_processor: transform inputs for easier batched model inference
|
|
26
26
|
model: core detection architecture
|
|
27
27
|
split_wide_crops: wether to use crop splitting for high aspect ratio crops
|
|
@@ -44,9 +44,9 @@ class RecognitionPredictor(nn.Module):
|
|
|
44
44
|
@torch.inference_mode()
|
|
45
45
|
def forward(
|
|
46
46
|
self,
|
|
47
|
-
crops: Sequence[
|
|
47
|
+
crops: Sequence[np.ndarray | torch.Tensor],
|
|
48
48
|
**kwargs: Any,
|
|
49
|
-
) ->
|
|
49
|
+
) -> list[tuple[str, float]]:
|
|
50
50
|
if len(crops) == 0:
|
|
51
51
|
return []
|
|
52
52
|
# Dimension check
|
|
@@ -67,7 +67,7 @@ class RecognitionPredictor(nn.Module):
|
|
|
67
67
|
crops = new_crops
|
|
68
68
|
|
|
69
69
|
# Resize & batch them
|
|
70
|
-
processed_batches = self.pre_processor(crops)
|
|
70
|
+
processed_batches = self.pre_processor(crops) # type: ignore[arg-type]
|
|
71
71
|
|
|
72
72
|
# Forward it
|
|
73
73
|
_params = next(self.model.parameters())
|