python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
doctr/models/recognition/parseq/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -18,7 +18,7 @@ from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
 
 from ...classification import vit_s
-from ...utils.pytorch import load_pretrained_params
+from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
@@ -29,7 +29,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 128),
         "vocab": VOCABS["french"],
-        "url":
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/parseq-56125471.pt&src=0",
     },
 }
 
@@ -38,6 +38,7 @@ class CharEmbedding(nn.Module):
     """Implements the character embedding module
 
     Args:
+    ----
         vocab_size: size of the vocabulary
         d_model: dimension of the model
     """
@@ -55,6 +56,7 @@ class PARSeqDecoder(nn.Module):
     """Implements decoder module of the PARSeq model
 
     Args:
+    ----
         d_model: dimension of the model
         num_heads: number of attention heads
         ffd: dimension of the feed forward layer
@@ -110,6 +112,7 @@ class PARSeq(_PARSeq, nn.Module):
     Slightly modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
 
     Args:
+    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -197,11 +200,11 @@ class PARSeq(_PARSeq, nn.Module):
             final_perms = torch.stack(perms)
         if len(perm_pool):
             i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False)
-            final_perms = torch.cat([final_perms, perm_pool[i]])
+            final_perms = torch.cat([final_perms, perm_pool[i]])
         else:
-            perms.extend(
-                [torch.randperm(max_num_chars, device=seqlen.device) for _ in range(num_gen_perms - len(perms))]
-            )
+            perms.extend([
+                torch.randperm(max_num_chars, device=seqlen.device) for _ in range(num_gen_perms - len(perms))
+            ])
             final_perms = torch.stack(perms)
 
         comp = final_perms.flip(-1)
@@ -209,7 +212,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
         eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
-        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
+        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()  # type: ignore
         if len(combined) > 1:
             combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
         return combined
@@ -237,7 +240,6 @@ class PARSeq(_PARSeq, nn.Module):
         target_query: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Add positional information to the target sequence and pass it through the decoder."""
-
         batch_size, sequence_length = target.shape
         # apply positional information to the target sequence excluding the SOS token
         null_ctx = self.embed(target[:, :1])
@@ -280,7 +282,7 @@ class PARSeq(_PARSeq, nn.Module):
             ys[:, i + 1] = pos_prob.squeeze().argmax(-1)
 
             # Stop decoding if all sequences have reached the EOS token
-            if max_len is None and (ys == self.vocab_size).any(dim=-1).all():
+            if max_len is None and (ys == self.vocab_size).any(dim=-1).all():  # type: ignore[attr-defined]
                 break
 
         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)
@@ -295,7 +297,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         # Create padding mask for refined target input maskes all behind EOS token as False
         # (N, 1, 1, max_length)
-        target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+        target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)  # type: ignore[attr-defined]
         mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
         logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
 
@@ -329,11 +331,9 @@ class PARSeq(_PARSeq, nn.Module):
         gt_out = gt[:, 1:]  # remove SOS token
         # Create padding mask for target input
         # [True, True, True, ..., False, False, False] -> False is masked
-        padding_mask = (
-            ~(((gt_in == self.vocab_size + 2) | (gt_in == self.vocab_size)).int().cumsum(-1) > 0)
-            .unsqueeze(1)
-            .unsqueeze(1)
-        )  # (N, 1, 1, seq_len)
+        padding_mask = ~(
+            ((gt_in == self.vocab_size + 2) | (gt_in == self.vocab_size)).int().cumsum(-1) > 0
+        ).unsqueeze(1).unsqueeze(1)  # (N, 1, 1, seq_len)
 
         loss = torch.tensor(0.0, device=features.device)
         loss_numel: Union[int, float] = 0
@@ -362,6 +362,8 @@ class PARSeq(_PARSeq, nn.Module):
         else:
             logits = self.decode_autoregressive(features)
 
+        logits = _bf16_to_float32(logits)
+
         out: Dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
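Both PARSeq backends now import a `_bf16_to_float32` helper and pass the logits through it before post-processing, so models run under bfloat16 autocast hand full-precision tensors to the numpy-based decoding step. The helper's body is not part of this diff; a minimal sketch of what the import and call site imply (the body is an assumption, only the name and usage come from the diff):

import torch

def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
    # upcast bfloat16 activations so numpy post-processing keeps full precision
    return x.float() if x.dtype == torch.bfloat16 else x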
@@ -384,6 +386,7 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
     """Post processor for PARSeq architecture
 
     Args:
+    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
@@ -393,18 +396,19 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
-
-        probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
-        # Take the minimum confidence of the sequence
-        probs = probs.min(dim=1).values.detach().cpu()
+        preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
 
         # Manual decoding
         word_values = [
             "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
             for encoded_seq in out_idxs.cpu().numpy()
         ]
+        # compute probabilties for each word up to the EOS token
+        probs = [
+            preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
+        ]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        return list(zip(word_values, probs))
 
 
 def _parseq(
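The post-processor hunk above changes how recognition confidence is scored: 0.7.0 took the minimum softmax probability over the whole decoded sequence (padding steps included), while 0.8.1 averages the per-step maxima over just the characters preceding the EOS token, so one low-confidence padding step no longer drags a long word's score toward zero. A toy comparison built only from operations visible in the diff (the random logits and stand-in strings are illustrative):

import torch

logits = torch.randn(2, 8, 5)  # (batch, max_length, vocab_size + 1)
preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]  # best per-step probability

# 0.7.0 scheme: minimum step confidence across the full sequence
old_scores = preds_prob.min(dim=1).values

# 0.8.1 scheme: mean step confidence over the decoded word only
word_values = ["ab", "abc"]  # stand-in decoded strings
new_scores = [
    preds_prob[i, : len(w)].clip(0, 1).mean().item() if w else 0.0
    for i, w in enumerate(word_values)
]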
@@ -457,12 +461,14 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: keyword arguments of the PARSeq architecture
 
     Returns:
+    -------
         text recognition architecture
     """
-
     return _parseq(
         "parseq",
         pretrained,
doctr/models/recognition/parseq/tensorflow.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
 
 from ...classification import vit_s
-from ...utils.tensorflow import load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
@@ -36,6 +36,7 @@ class CharEmbedding(layers.Layer):
     """Implements the character embedding module
 
     Args:
+    ----
         vocab_size: size of the vocabulary
         d_model: dimension of the model
     """
@@ -53,6 +54,7 @@ class PARSeqDecoder(layers.Layer):
     """Implements decoder module of the PARSeq model
 
     Args:
+    ----
         d_model: dimension of the model
         num_heads: number of attention heads
         ffd: dimension of the feed forward layer
@@ -113,6 +115,7 @@ class PARSeq(_PARSeq, Model):
     Modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
 
     Args:
+    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -191,9 +194,9 @@ class PARSeq(_PARSeq, Model):
             i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False)
             final_perms = tf.concat([final_perms, perm_pool[i[0] : i[1]]], axis=0)
         else:
-            perms.extend(
-                [tf.random.shuffle(tf.range(max_num_chars, dtype=tf.int32)) for _ in range(num_gen_perms - len(perms))]
-            )
+            perms.extend([
+                tf.random.shuffle(tf.range(max_num_chars, dtype=tf.int32)) for _ in range(num_gen_perms - len(perms))
+            ])
             final_perms = tf.stack(perms)
 
         comp = tf.reverse(final_perms, axis=[-1])
@@ -390,6 +393,8 @@ class PARSeq(_PARSeq, Model):
         else:
             logits = self.decode_autoregressive(features, **kwargs)
 
+        logits = _bf16_to_float32(logits)
+
         out: Dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = logits
@@ -412,6 +417,7 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
     """Post processor for PARSeq architecture
 
     Args:
+    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
@@ -421,10 +427,7 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
-
-        probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
-        # Take the minimum confidence of the sequence
-        probs = tf.math.reduce_min(probs, axis=1)
+        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
 
         # decode raw output of the model with tf_label_to_idx
         out_idxs = tf.cast(out_idxs, dtype="int32")
@@ -434,7 +437,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        # compute probabilties for each word up to the EOS token
+        probs = [
+            preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
+            for i, word in enumerate(word_values)
+        ]
+
+        return list(zip(word_values, probs))
 
 
 def _parseq(
@@ -484,12 +493,14 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: keyword arguments of the PARSeq architecture
 
     Returns:
+    -------
         text recognition architecture
     """
-
     return _parseq(
         "parseq",
         pretrained,
doctr/models/recognition/predictor/_utils.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -22,6 +22,7 @@ def split_crops(
     """Chunk crops horizontally to match a given aspect ratio
 
     Args:
+    ----
         crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
         max_ratio: the maximum aspect ratio that won't trigger the chunk
         target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
@@ -29,9 +30,9 @@ def split_crops(
         channels_last: whether the numpy array has dimensions in channels last order
 
     Returns:
+    -------
         a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
     """
-
     _remap_required = False
     crop_map: List[Union[int, Tuple[int, int]]] = []
     new_crops: List[np.ndarray] = []
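For context, `split_crops` is the function these docstring fixes describe: a crop whose width/height ratio exceeds `max_ratio` is chunked into overlapping pieces close to `target_ratio`, and the returned mapping lets the predictor stitch per-chunk transcriptions back together. A hypothetical call (argument order inferred from the docstring above, not guaranteed):

import numpy as np
from doctr.models.recognition.predictor._utils import split_crops

crops = [np.zeros((32, 400, 3), dtype=np.uint8)]  # aspect ratio 12.5, above max_ratio
new_crops, crop_map, remap_required = split_crops(crops, 8, 6, 1.4, True)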
doctr/models/recognition/predictor/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -21,6 +21,7 @@ class RecognitionPredictor(nn.Module):
     """Implements an object able to identify character sequences in images
 
     Args:
+    ----
         pre_processor: transform inputs for easier batched model inference
         model: core detection architecture
         split_wide_crops: wether to use crop splitting for high aspect ratio crops
@@ -40,7 +41,7 @@ class RecognitionPredictor(nn.Module):
         self.dil_factor = 1.4  # Dilation factor to overlap the crops
         self.target_ar = 6  # Target aspect ratio
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def forward(
         self,
         crops: Sequence[Union[np.ndarray, torch.Tensor]],
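The only functional change here is the decorator swap on `forward`: `torch.inference_mode()` disables gradient tracking like `torch.no_grad()` but additionally skips view and version-counter bookkeeping, so pure-inference code runs slightly faster; the trade-off is that its outputs cannot later re-enter autograd. A small illustration (not from the diff):

import torch

model = torch.nn.Linear(4, 2).eval()

with torch.no_grad():  # 0.7.0 behaviour: autograd off, tensors still versioned
    _ = model(torch.rand(1, 4))

with torch.inference_mode():  # 0.8.1 behaviour: also skips version tracking
    out = model(torch.rand(1, 4))

# `out` is an inference tensor: using it in an autograd graph raises a RuntimeError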
doctr/models/recognition/predictor/tensorflow.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -21,6 +21,7 @@ class RecognitionPredictor(NestedObject):
     """Implements an object able to identify character sequences in images
 
     Args:
+    ----
         pre_processor: transform inputs for easier batched model inference
         model: core detection architecture
         split_wide_crops: wether to use crop splitting for high aspect ratio crops
doctr/models/recognition/sar/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -14,7 +14,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS
 
 from ...classification import resnet31
-from ...utils.pytorch import load_pretrained_params
+from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
@@ -25,7 +25,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 128),
         "vocab": VOCABS["french"],
-        "url":
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/sar_resnet31-9a1deedf.pt&src=0",
     },
 }
 
@@ -80,6 +80,7 @@ class SARDecoder(nn.Module):
     """Implements decoder module of the SAR model
 
     Args:
+    ----
         rnn_units: number of hidden units in recurrent cells
         max_length: maximum length of a sequence
         vocab_size: number of classes in the model alphabet
@@ -164,6 +165,7 @@ class SAR(nn.Module, RecognitionModel):
     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
+    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -249,7 +251,7 @@ class SAR(nn.Module, RecognitionModel):
         if self.training and target is None:
             raise ValueError("Need to provide labels during training for teacher forcing")
 
-        decoded_features = self.decoder(features, encoded, gt=None if target is None else gt)
+        decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))
 
         out: Dict[str, Any] = {}
         if self.exportable:
@@ -278,17 +280,19 @@ class SAR(nn.Module, RecognitionModel):
         Sequences are masked after the EOS character.
 
         Args:
+        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
+        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1  # type: ignore[assignment]
         # Compute loss
         # (N, L, vocab_size + 1)
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
@@ -303,6 +307,7 @@ class SARPostProcessor(RecognitionPostProcessor):
     """Post processor for SAR architectures
 
     Args:
+    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
@@ -323,7 +328,7 @@ class SARPostProcessor(RecognitionPostProcessor):
             for encoded_seq in out_idxs.detach().cpu().numpy()
         ]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
 
 
 def _sar(
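The `.clip(0, 1)` added to the returned probabilities is a guard rather than a behaviour change: under reduced precision, softmax outputs can overshoot 1.0 by a few ulps, which would otherwise surface as a confidence above 100%. For instance:

import numpy as np

probs = np.array([0.98, 1.0000002, 0.87], dtype=np.float32)  # float16/bf16 round-off can exceed 1.0
probs.clip(0, 1)  # -> array([0.98, 1.0, 0.87], dtype=float32)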
@@ -373,12 +378,14 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: keyword arguments of the SAR architecture
 
     Returns:
+    -------
         text recognition architecture
     """
-
     return _sar(
         "sar_resnet31",
         pretrained,
doctr/models/recognition/sar/tensorflow.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
 from doctr.utils.repr import NestedObject
 
 from ...classification import resnet31
-from ...utils.tensorflow import load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
@@ -33,18 +33,17 @@ class SAREncoder(layers.Layer, NestedObject):
     """Implements encoder module of the SAR model
 
     Args:
+    ----
         rnn_units: number of hidden rnn units
         dropout_prob: dropout probability
     """
 
     def __init__(self, rnn_units: int, dropout_prob: float = 0.0) -> None:
         super().__init__()
-        self.rnn = Sequential(
-            [
-                layers.LSTM(units=rnn_units, return_sequences=True, recurrent_dropout=dropout_prob),
-                layers.LSTM(units=rnn_units, return_sequences=False, recurrent_dropout=dropout_prob),
-            ]
-        )
+        self.rnn = Sequential([
+            layers.LSTM(units=rnn_units, return_sequences=True, recurrent_dropout=dropout_prob),
+            layers.LSTM(units=rnn_units, return_sequences=False, recurrent_dropout=dropout_prob),
+        ])
 
     def call(
         self,
@@ -59,6 +58,7 @@ class AttentionModule(layers.Layer, NestedObject):
     """Implements attention module of the SAR model
 
     Args:
+    ----
         attention_units: number of hidden attention units
 
     """
@@ -120,6 +120,7 @@ class SARDecoder(layers.Layer, NestedObject):
     """Implements decoder module of the SAR model
 
     Args:
+    ----
         rnn_units: number of hidden units in recurrent cells
         max_length: maximum length of a sequence
         vocab_size: number of classes in the model alphabet
@@ -147,9 +148,9 @@ class SARDecoder(layers.Layer, NestedObject):
         self.embed = layers.Dense(embedding_units, use_bias=False)
         self.embed_tgt = layers.Embedding(embedding_units, self.vocab_size + 1)
 
-        self.lstm_cells = layers.StackedRNNCells(
-            [layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_cells)]
-        )
+        self.lstm_cells = layers.StackedRNNCells([
+            layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_cells)
+        ])
         self.attention_module = AttentionModule(attention_units)
         self.output_dense = layers.Dense(self.vocab_size + 1, use_bias=True)
         self.dropout = layers.Dropout(dropout_prob)
@@ -215,6 +216,7 @@ class SAR(Model, RecognitionModel):
     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
+    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -273,11 +275,13 @@ class SAR(Model, RecognitionModel):
         Sequences are masked after the EOS character.
 
         Args:
+        ----
             gt: the encoded tensor with gt labels
             model_output: predicted logits of the model
             seq_len: lengths of each gt word inside the batch
 
         Returns:
+        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
@@ -316,7 +320,9 @@ class SAR(Model, RecognitionModel):
         if kwargs.get("training", False) and target is None:
             raise ValueError("Need to provide labels during training for teacher forcing")
 
-        decoded_features = self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
+        decoded_features = _bf16_to_float32(
+            self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
+        )
 
         out: Dict[str, tf.Tensor] = {}
         if self.exportable:
@@ -340,6 +346,7 @@ class SARPostProcessor(RecognitionPostProcessor):
     """Post processor for SAR architectures
 
     Args:
+    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
@@ -362,7 +369,7 @@ class SARPostProcessor(RecognitionPostProcessor):
         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
 
 
 def _sar(
@@ -409,10 +416,12 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
    >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: keyword arguments of the SAR architecture
 
     Returns:
+    -------
         text recognition architecture
     """
-
     return _sar("sar_resnet31", pretrained, resnet31, **kwargs)
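Mirroring the PyTorch file, the TensorFlow model imports its own `_bf16_to_float32` from `...utils.tensorflow`. Its body is likewise not shown in this diff; a plausible sketch, assumed from the name and call sites only:

import tensorflow as tf

def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
    # assumed mirror of the PyTorch helper: upcast bfloat16 outputs
    # before `.numpy()`-based post-processing
    return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x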
doctr/models/recognition/utils.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -14,12 +14,14 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
     """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
 
     Args:
+    ----
         a: first char seq, suffix should be similar to b's prefix.
         b: second char seq, prefix should be similar to a's suffix.
         dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
             only used when the mother sequence is splitted on a character repetition
 
     Returns:
+    -------
         A merged character sequence.
 
     Example::
@@ -63,11 +65,13 @@ def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
     """Recursively merges consecutive string sequences with overlapping characters.
 
     Args:
+    ----
         seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
         dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
             only used when the mother sequence is splitted on a character repetition
 
     Returns:
+    -------
         A merged character sequence
 
     Example::
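A usage sketch for the documented helper (hypothetical inputs; the exact merged output depends on how the overlaps align, so the expected value is indicative only):

from doctr.models.recognition.utils import merge_multi_strings

# chunks of one wide crop, ordered left to right, with overlapping characters
merge_multi_strings(["abcdef", "defghi", "ghijkl"], dil_factor=1.4)
# expected to resolve the overlaps into something like 'abcdefghijkl'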
doctr/models/recognition/vitstr/base.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -23,9 +23,11 @@ class _ViTSTR:
         sequence lengths.
 
         Args:
+        ----
            gts: list of ground-truth labels
 
        Returns:
+        -------
            A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
        """
        encoded = encode_sequences(
@@ -43,6 +45,7 @@ class _ViTSTRPostProcessor(RecognitionPostProcessor):
     """Abstract class to postprocess the raw output of the model
 
     Args:
+    ----
         vocab: string containing the ordered sequence of supported characters
     """