python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -45,10 +45,10 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
class CTCPostProcessor(RecognitionPostProcessor):
|
|
48
|
-
"""
|
|
49
|
-
Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
48
|
+
"""Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
50
49
|
|
|
51
50
|
Args:
|
|
51
|
+
----
|
|
52
52
|
vocab: string containing the ordered sequence of supported characters
|
|
53
53
|
"""
|
|
54
54
|
|
|
@@ -62,14 +62,15 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
62
62
|
<https://github.com/githubharald/CTCDecoder>`_.
|
|
63
63
|
|
|
64
64
|
Args:
|
|
65
|
+
----
|
|
65
66
|
logits: model output, shape: N x T x C
|
|
66
67
|
vocab: vocabulary to use
|
|
67
68
|
blank: index of blank label
|
|
68
69
|
|
|
69
70
|
Returns:
|
|
71
|
+
-------
|
|
70
72
|
A list of tuples: (word, confidence)
|
|
71
73
|
"""
|
|
72
|
-
|
|
73
74
|
# Gather the most confident characters, and assign the smallest conf among those to the sequence prob
|
|
74
75
|
probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
|
|
75
76
|
|
|
@@ -82,14 +83,15 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
82
83
|
return list(zip(words, probs.tolist()))
|
|
83
84
|
|
|
84
85
|
def __call__(self, logits: torch.Tensor) -> List[Tuple[str, float]]:
|
|
85
|
-
"""
|
|
86
|
-
Performs decoding of raw output with CTC and decoding of CTC predictions
|
|
86
|
+
"""Performs decoding of raw output with CTC and decoding of CTC predictions
|
|
87
87
|
with label_to_idx mapping dictionnary
|
|
88
88
|
|
|
89
89
|
Args:
|
|
90
|
+
----
|
|
90
91
|
logits: raw output of the model, shape (N, C + 1, seq_len)
|
|
91
92
|
|
|
92
93
|
Returns:
|
|
94
|
+
-------
|
|
93
95
|
A tuple of 2 lists: a list of str (words) and a list of float (probs)
|
|
94
96
|
|
|
95
97
|
"""
|
|
@@ -102,6 +104,7 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
102
104
|
Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
|
|
103
105
|
|
|
104
106
|
Args:
|
|
107
|
+
----
|
|
105
108
|
feature_extractor: the backbone serving as feature extractor
|
|
106
109
|
vocab: vocabulary used for encoding
|
|
107
110
|
rnn_units: number of units in the LSTM layers
|
|
@@ -128,12 +131,9 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
128
131
|
self.feat_extractor = feature_extractor
|
|
129
132
|
|
|
130
133
|
# Resolve the input_size of the LSTM
|
|
131
|
-
|
|
132
|
-
with torch.no_grad():
|
|
134
|
+
with torch.inference_mode():
|
|
133
135
|
out_shape = self.feat_extractor(torch.zeros((1, *input_shape))).shape
|
|
134
136
|
lstm_in = out_shape[1] * out_shape[2]
|
|
135
|
-
# Switch back to original mode
|
|
136
|
-
self.feat_extractor.train()
|
|
137
137
|
|
|
138
138
|
self.decoder = nn.LSTM(
|
|
139
139
|
input_size=lstm_in,
|
|
@@ -168,11 +168,12 @@ class CRNN(RecognitionModel, nn.Module):
|
|
|
168
168
|
"""Compute CTC loss for the model.
|
|
169
169
|
|
|
170
170
|
Args:
|
|
171
|
-
|
|
171
|
+
----
|
|
172
172
|
model_output: predicted logits of the model
|
|
173
|
-
|
|
173
|
+
target: list of target strings
|
|
174
174
|
|
|
175
175
|
Returns:
|
|
176
|
+
-------
|
|
176
177
|
The loss of the model on the batch
|
|
177
178
|
"""
|
|
178
179
|
gt, seq_len = self.build_target(target)
|
|
@@ -249,7 +250,7 @@ def _crnn(
|
|
|
249
250
|
_cfg["input_shape"] = kwargs["input_shape"]
|
|
250
251
|
|
|
251
252
|
# Build the model
|
|
252
|
-
model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
|
|
253
|
+
model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
|
|
253
254
|
# Load pretrained parameters
|
|
254
255
|
if pretrained:
|
|
255
256
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
@@ -271,12 +272,14 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
271
272
|
>>> out = model(input_tensor)
|
|
272
273
|
|
|
273
274
|
Args:
|
|
275
|
+
----
|
|
274
276
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
277
|
+
**kwargs: keyword arguments of the CRNN architecture
|
|
275
278
|
|
|
276
279
|
Returns:
|
|
280
|
+
-------
|
|
277
281
|
text recognition architecture
|
|
278
282
|
"""
|
|
279
|
-
|
|
280
283
|
return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, ignore_keys=["linear.weight", "linear.bias"], **kwargs)
|
|
281
284
|
|
|
282
285
|
|
|
@@ -291,12 +294,14 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
291
294
|
>>> out = model(input_tensor)
|
|
292
295
|
|
|
293
296
|
Args:
|
|
297
|
+
----
|
|
294
298
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
299
|
+
**kwargs: keyword arguments of the CRNN architecture
|
|
295
300
|
|
|
296
301
|
Returns:
|
|
302
|
+
-------
|
|
297
303
|
text recognition architecture
|
|
298
304
|
"""
|
|
299
|
-
|
|
300
305
|
return _crnn(
|
|
301
306
|
"crnn_mobilenet_v3_small",
|
|
302
307
|
pretrained,
|
|
@@ -317,12 +322,14 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
317
322
|
>>> out = model(input_tensor)
|
|
318
323
|
|
|
319
324
|
Args:
|
|
325
|
+
----
|
|
320
326
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
327
|
+
**kwargs: keyword arguments of the CRNN architecture
|
|
321
328
|
|
|
322
329
|
Returns:
|
|
330
|
+
-------
|
|
323
331
|
text recognition architecture
|
|
324
332
|
"""
|
|
325
|
-
|
|
326
333
|
return _crnn(
|
|
327
334
|
"crnn_mobilenet_v3_large",
|
|
328
335
|
pretrained,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -13,7 +13,7 @@ from tensorflow.keras.models import Model, Sequential
|
|
|
13
13
|
from doctr.datasets import VOCABS
|
|
14
14
|
|
|
15
15
|
from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
|
|
16
|
-
from ...utils.tensorflow import load_pretrained_params
|
|
16
|
+
from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
|
|
17
17
|
from ..core import RecognitionModel, RecognitionPostProcessor
|
|
18
18
|
|
|
19
19
|
__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
|
|
@@ -44,10 +44,10 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
44
44
|
|
|
45
45
|
|
|
46
46
|
class CTCPostProcessor(RecognitionPostProcessor):
|
|
47
|
-
"""
|
|
48
|
-
Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
47
|
+
"""Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
49
48
|
|
|
50
49
|
Args:
|
|
50
|
+
----
|
|
51
51
|
vocab: string containing the ordered sequence of supported characters
|
|
52
52
|
ignore_case: if True, ignore case of letters
|
|
53
53
|
ignore_accents: if True, ignore accents of letters
|
|
@@ -59,16 +59,17 @@ class CTCPostProcessor(RecognitionPostProcessor):
|
|
|
59
59
|
beam_width: int = 1,
|
|
60
60
|
top_paths: int = 1,
|
|
61
61
|
) -> Union[List[Tuple[str, float]], List[Tuple[List[str], List[float]]]]:
|
|
62
|
-
"""
|
|
63
|
-
Performs decoding of raw output with CTC and decoding of CTC predictions
|
|
62
|
+
"""Performs decoding of raw output with CTC and decoding of CTC predictions
|
|
64
63
|
with label_to_idx mapping dictionnary
|
|
65
64
|
|
|
66
65
|
Args:
|
|
66
|
+
----
|
|
67
67
|
logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1
|
|
68
68
|
beam_width: An int scalar >= 0 (beam search beam width).
|
|
69
69
|
top_paths: An int scalar >= 0, <= beam_width (controls output size).
|
|
70
70
|
|
|
71
71
|
Returns:
|
|
72
|
+
-------
|
|
72
73
|
A list of decoded words of length BATCH_SIZE
|
|
73
74
|
|
|
74
75
|
|
|
@@ -113,6 +114,7 @@ class CRNN(RecognitionModel, Model):
|
|
|
113
114
|
Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
|
|
114
115
|
|
|
115
116
|
Args:
|
|
117
|
+
----
|
|
116
118
|
feature_extractor: the backbone serving as feature extractor
|
|
117
119
|
vocab: vocabulary used for encoding
|
|
118
120
|
rnn_units: number of units in the LSTM layers
|
|
@@ -144,13 +146,11 @@ class CRNN(RecognitionModel, Model):
|
|
|
144
146
|
self.exportable = exportable
|
|
145
147
|
self.feat_extractor = feature_extractor
|
|
146
148
|
|
|
147
|
-
self.decoder = Sequential(
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
]
|
|
153
|
-
)
|
|
149
|
+
self.decoder = Sequential([
|
|
150
|
+
layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
|
|
151
|
+
layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
|
|
152
|
+
layers.Dense(units=len(vocab) + 1),
|
|
153
|
+
])
|
|
154
154
|
self.decoder.build(input_shape=(None, w, h * c))
|
|
155
155
|
|
|
156
156
|
self.postprocessor = CTCPostProcessor(vocab=vocab)
|
|
@@ -166,10 +166,12 @@ class CRNN(RecognitionModel, Model):
|
|
|
166
166
|
"""Compute CTC loss for the model.
|
|
167
167
|
|
|
168
168
|
Args:
|
|
169
|
+
----
|
|
169
170
|
model_output: predicted logits of the model
|
|
170
171
|
target: lengths of each gt word inside the batch
|
|
171
172
|
|
|
172
173
|
Returns:
|
|
174
|
+
-------
|
|
173
175
|
The loss of the model on the batch
|
|
174
176
|
"""
|
|
175
177
|
gt, seq_len = self.build_target(target)
|
|
@@ -199,7 +201,7 @@ class CRNN(RecognitionModel, Model):
|
|
|
199
201
|
w, h, c = transposed_feat.get_shape().as_list()[1:]
|
|
200
202
|
# B x W x H x C --> B x W x H * C
|
|
201
203
|
features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
|
|
202
|
-
logits = self.decoder(features_seq, **kwargs)
|
|
204
|
+
logits = _bf16_to_float32(self.decoder(features_seq, **kwargs))
|
|
203
205
|
|
|
204
206
|
out: Dict[str, tf.Tensor] = {}
|
|
205
207
|
if self.exportable:
|
|
@@ -261,12 +263,14 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
261
263
|
>>> out = model(input_tensor)
|
|
262
264
|
|
|
263
265
|
Args:
|
|
266
|
+
----
|
|
264
267
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
268
|
+
**kwargs: keyword arguments of the CRNN architecture
|
|
265
269
|
|
|
266
270
|
Returns:
|
|
271
|
+
-------
|
|
267
272
|
text recognition architecture
|
|
268
273
|
"""
|
|
269
|
-
|
|
270
274
|
return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs)
|
|
271
275
|
|
|
272
276
|
|
|
@@ -281,12 +285,14 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
281
285
|
>>> out = model(input_tensor)
|
|
282
286
|
|
|
283
287
|
Args:
|
|
288
|
+
----
|
|
284
289
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
290
|
+
**kwargs: keyword arguments of the CRNN architecture
|
|
285
291
|
|
|
286
292
|
Returns:
|
|
293
|
+
-------
|
|
287
294
|
text recognition architecture
|
|
288
295
|
"""
|
|
289
|
-
|
|
290
296
|
return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs)
|
|
291
297
|
|
|
292
298
|
|
|
@@ -301,10 +307,12 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
|
|
|
301
307
|
>>> out = model(input_tensor)
|
|
302
308
|
|
|
303
309
|
Args:
|
|
310
|
+
----
|
|
304
311
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
312
|
+
**kwargs: keyword arguments of the CRNN architecture
|
|
305
313
|
|
|
306
314
|
Returns:
|
|
315
|
+
-------
|
|
307
316
|
text recognition architecture
|
|
308
317
|
"""
|
|
309
|
-
|
|
310
318
|
return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -23,9 +23,11 @@ class _MASTER:
|
|
|
23
23
|
sequence lengths.
|
|
24
24
|
|
|
25
25
|
Args:
|
|
26
|
+
----
|
|
26
27
|
gts: list of ground-truth labels
|
|
27
28
|
|
|
28
29
|
Returns:
|
|
30
|
+
-------
|
|
29
31
|
A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
|
|
30
32
|
"""
|
|
31
33
|
encoded = encode_sequences(
|
|
@@ -44,6 +46,7 @@ class _MASTERPostProcessor(RecognitionPostProcessor):
|
|
|
44
46
|
"""Abstract class to postprocess the raw output of the model
|
|
45
47
|
|
|
46
48
|
Args:
|
|
49
|
+
----
|
|
47
50
|
vocab: string containing the ordered sequence of supported characters
|
|
48
51
|
"""
|
|
49
52
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -15,7 +15,7 @@ from doctr.datasets import VOCABS
|
|
|
15
15
|
from doctr.models.classification import magc_resnet31
|
|
16
16
|
from doctr.models.modules.transformer import Decoder, PositionalEncoding
|
|
17
17
|
|
|
18
|
-
from ...utils.pytorch import load_pretrained_params
|
|
18
|
+
from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
|
|
19
19
|
from .base import _MASTER, _MASTERPostProcessor
|
|
20
20
|
|
|
21
21
|
__all__ = ["MASTER", "master"]
|
|
@@ -27,7 +27,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
27
27
|
"std": (0.299, 0.296, 0.301),
|
|
28
28
|
"input_shape": (3, 32, 128),
|
|
29
29
|
"vocab": VOCABS["french"],
|
|
30
|
-
"url":
|
|
30
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/master-fde31e4a.pt&src=0",
|
|
31
31
|
},
|
|
32
32
|
}
|
|
33
33
|
|
|
@@ -37,6 +37,7 @@ class MASTER(_MASTER, nn.Module):
|
|
|
37
37
|
Implementation based on the official Pytorch implementation: <https://github.com/wenwenyu/MASTER-pytorch>`_.
|
|
38
38
|
|
|
39
39
|
Args:
|
|
40
|
+
----
|
|
40
41
|
feature_extractor: the backbone serving as feature extractor
|
|
41
42
|
vocab: vocabulary, (without EOS, SOS, PAD)
|
|
42
43
|
d_model: d parameter for the transformer decoder
|
|
@@ -105,7 +106,8 @@ class MASTER(_MASTER, nn.Module):
|
|
|
105
106
|
# borrowed and slightly modified from https://github.com/wenwenyu/MASTER-pytorch
|
|
106
107
|
# NOTE: nn.TransformerDecoder takes the inverse from this implementation
|
|
107
108
|
# [True, True, True, ..., False, False, False] -> False is masked
|
|
108
|
-
|
|
109
|
+
# (N, 1, 1, max_length)
|
|
110
|
+
target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
|
|
109
111
|
target_length = target.size(1)
|
|
110
112
|
# sub mask filled diagonal with True = see and False = masked (max_length, max_length)
|
|
111
113
|
# NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
|
|
@@ -128,17 +130,19 @@ class MASTER(_MASTER, nn.Module):
|
|
|
128
130
|
Sequences are masked after the EOS character.
|
|
129
131
|
|
|
130
132
|
Args:
|
|
133
|
+
----
|
|
131
134
|
gt: the encoded tensor with gt labels
|
|
132
135
|
model_output: predicted logits of the model
|
|
133
136
|
seq_len: lengths of each gt word inside the batch
|
|
134
137
|
|
|
135
138
|
Returns:
|
|
139
|
+
-------
|
|
136
140
|
The loss of the model on the batch
|
|
137
141
|
"""
|
|
138
142
|
# Input length : number of timesteps
|
|
139
143
|
input_len = model_output.shape[1]
|
|
140
144
|
# Add one for additional <eos> token (sos disappear in shift!)
|
|
141
|
-
seq_len = seq_len + 1
|
|
145
|
+
seq_len = seq_len + 1 # type: ignore[assignment]
|
|
142
146
|
# Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
|
|
143
147
|
# The "masked" first gt char is <sos>. Delete last logit of the model output.
|
|
144
148
|
cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
|
|
@@ -159,15 +163,16 @@ class MASTER(_MASTER, nn.Module):
|
|
|
159
163
|
"""Call function for training
|
|
160
164
|
|
|
161
165
|
Args:
|
|
166
|
+
----
|
|
162
167
|
x: images
|
|
163
168
|
target: list of str labels
|
|
164
169
|
return_model_output: if True, return logits
|
|
165
170
|
return_preds: if True, decode logits
|
|
166
171
|
|
|
167
172
|
Returns:
|
|
173
|
+
-------
|
|
168
174
|
A dictionnary containing eventually loss, logits and predictions.
|
|
169
175
|
"""
|
|
170
|
-
|
|
171
176
|
# Encode
|
|
172
177
|
features = self.feat_extractor(x)["features"]
|
|
173
178
|
b, c, h, w = features.shape
|
|
@@ -195,6 +200,8 @@ class MASTER(_MASTER, nn.Module):
|
|
|
195
200
|
else:
|
|
196
201
|
logits = self.decode(encoded)
|
|
197
202
|
|
|
203
|
+
logits = _bf16_to_float32(logits)
|
|
204
|
+
|
|
198
205
|
if self.exportable:
|
|
199
206
|
out["logits"] = logits
|
|
200
207
|
return out
|
|
@@ -214,9 +221,11 @@ class MASTER(_MASTER, nn.Module):
|
|
|
214
221
|
"""Decode function for prediction
|
|
215
222
|
|
|
216
223
|
Args:
|
|
224
|
+
----
|
|
217
225
|
encoded: input tensor
|
|
218
226
|
|
|
219
|
-
|
|
227
|
+
Returns:
|
|
228
|
+
-------
|
|
220
229
|
A Tuple of torch.Tensor: predictions, logits
|
|
221
230
|
"""
|
|
222
231
|
b = encoded.size(0)
|
|
@@ -259,7 +268,7 @@ class MASTERPostProcessor(_MASTERPostProcessor):
|
|
|
259
268
|
for encoded_seq in out_idxs.cpu().numpy()
|
|
260
269
|
]
|
|
261
270
|
|
|
262
|
-
return list(zip(word_values, probs.numpy().tolist()))
|
|
271
|
+
return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
|
|
263
272
|
|
|
264
273
|
|
|
265
274
|
def _master(
|
|
@@ -307,12 +316,14 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
|
|
|
307
316
|
>>> out = model(input_tensor)
|
|
308
317
|
|
|
309
318
|
Args:
|
|
319
|
+
----
|
|
310
320
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
321
|
+
**kwargs: keywoard arguments passed to the MASTER architecture
|
|
311
322
|
|
|
312
323
|
Returns:
|
|
324
|
+
-------
|
|
313
325
|
text recognition architecture
|
|
314
326
|
"""
|
|
315
|
-
|
|
316
327
|
return _master(
|
|
317
328
|
"master",
|
|
318
329
|
pretrained,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
|
|
|
13
13
|
from doctr.models.classification import magc_resnet31
|
|
14
14
|
from doctr.models.modules.transformer import Decoder, PositionalEncoding
|
|
15
15
|
|
|
16
|
-
from ...utils.tensorflow import load_pretrained_params
|
|
16
|
+
from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
|
|
17
17
|
from .base import _MASTER, _MASTERPostProcessor
|
|
18
18
|
|
|
19
19
|
__all__ = ["MASTER", "master"]
|
|
@@ -31,11 +31,11 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
class MASTER(_MASTER, Model):
|
|
34
|
-
|
|
35
34
|
"""Implements MASTER as described in paper: <https://arxiv.org/pdf/1910.02562.pdf>`_.
|
|
36
35
|
Implementation based on the official TF implementation: <https://github.com/jiangxiluning/MASTER-TF>`_.
|
|
37
36
|
|
|
38
37
|
Args:
|
|
38
|
+
----
|
|
39
39
|
feature_extractor: the backbone serving as feature extractor
|
|
40
40
|
vocab: vocabulary, (without EOS, SOS, PAD)
|
|
41
41
|
d_model: d parameter for the transformer decoder
|
|
@@ -115,11 +115,13 @@ class MASTER(_MASTER, Model):
|
|
|
115
115
|
Sequences are masked after the EOS character.
|
|
116
116
|
|
|
117
117
|
Args:
|
|
118
|
+
----
|
|
118
119
|
gt: the encoded tensor with gt labels
|
|
119
120
|
model_output: predicted logits of the model
|
|
120
121
|
seq_len: lengths of each gt word inside the batch
|
|
121
122
|
|
|
122
123
|
Returns:
|
|
124
|
+
-------
|
|
123
125
|
The loss of the model on the batch
|
|
124
126
|
"""
|
|
125
127
|
# Input length : number of timesteps
|
|
@@ -150,15 +152,17 @@ class MASTER(_MASTER, Model):
|
|
|
150
152
|
"""Call function for training
|
|
151
153
|
|
|
152
154
|
Args:
|
|
155
|
+
----
|
|
153
156
|
x: images
|
|
154
157
|
target: list of str labels
|
|
155
158
|
return_model_output: if True, return logits
|
|
156
159
|
return_preds: if True, decode logits
|
|
160
|
+
**kwargs: keyword arguments passed to the decoder
|
|
157
161
|
|
|
158
|
-
|
|
162
|
+
Returns:
|
|
163
|
+
-------
|
|
159
164
|
A dictionnary containing eventually loss, logits and predictions.
|
|
160
165
|
"""
|
|
161
|
-
|
|
162
166
|
# Encode
|
|
163
167
|
feature = self.feat_extractor(x, **kwargs)
|
|
164
168
|
b, h, w, c = feature.get_shape()
|
|
@@ -183,6 +187,8 @@ class MASTER(_MASTER, Model):
|
|
|
183
187
|
else:
|
|
184
188
|
logits = self.decode(encoded, **kwargs)
|
|
185
189
|
|
|
190
|
+
logits = _bf16_to_float32(logits)
|
|
191
|
+
|
|
186
192
|
if self.exportable:
|
|
187
193
|
out["logits"] = logits
|
|
188
194
|
return out
|
|
@@ -203,9 +209,12 @@ class MASTER(_MASTER, Model):
|
|
|
203
209
|
"""Decode function for prediction
|
|
204
210
|
|
|
205
211
|
Args:
|
|
212
|
+
----
|
|
206
213
|
encoded: encoded features
|
|
214
|
+
**kwargs: keyword arguments passed to the decoder
|
|
207
215
|
|
|
208
|
-
|
|
216
|
+
Returns:
|
|
217
|
+
-------
|
|
209
218
|
A Tuple of tf.Tensor: predictions, logits
|
|
210
219
|
"""
|
|
211
220
|
b = encoded.shape[0]
|
|
@@ -238,6 +247,7 @@ class MASTERPostProcessor(_MASTERPostProcessor):
|
|
|
238
247
|
"""Post processor for MASTER architectures
|
|
239
248
|
|
|
240
249
|
Args:
|
|
250
|
+
----
|
|
241
251
|
vocab: string containing the ordered sequence of supported characters
|
|
242
252
|
"""
|
|
243
253
|
|
|
@@ -260,7 +270,7 @@ class MASTERPostProcessor(_MASTERPostProcessor):
|
|
|
260
270
|
decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
|
|
261
271
|
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
|
|
262
272
|
|
|
263
|
-
return list(zip(word_values, probs.numpy().tolist()))
|
|
273
|
+
return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
|
|
264
274
|
|
|
265
275
|
|
|
266
276
|
def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool = True, **kwargs: Any) -> MASTER:
|
|
@@ -297,10 +307,12 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
|
|
|
297
307
|
>>> out = model(input_tensor)
|
|
298
308
|
|
|
299
309
|
Args:
|
|
310
|
+
----
|
|
300
311
|
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
312
|
+
**kwargs: keywoard arguments passed to the MASTER architecture
|
|
301
313
|
|
|
302
314
|
Returns:
|
|
315
|
+
-------
|
|
303
316
|
text recognition architecture
|
|
304
317
|
"""
|
|
305
|
-
|
|
306
318
|
return _master("master", pretrained, magc_resnet31, **kwargs)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -23,9 +23,11 @@ class _PARSeq:
|
|
|
23
23
|
sequence lengths.
|
|
24
24
|
|
|
25
25
|
Args:
|
|
26
|
+
----
|
|
26
27
|
gts: list of ground-truth labels
|
|
27
28
|
|
|
28
29
|
Returns:
|
|
30
|
+
-------
|
|
29
31
|
A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
|
|
30
32
|
"""
|
|
31
33
|
encoded = encode_sequences(
|
|
@@ -44,6 +46,7 @@ class _PARSeqPostProcessor(RecognitionPostProcessor):
|
|
|
44
46
|
"""Abstract class to postprocess the raw output of the model
|
|
45
47
|
|
|
46
48
|
Args:
|
|
49
|
+
----
|
|
47
50
|
vocab: string containing the ordered sequence of supported characters
|
|
48
51
|
"""
|
|
49
52
|
|