python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/sar/tensorflow.py

```diff
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, Sequential, layers
@@ -18,7 +18,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "sar_resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -33,7 +33,6 @@ class SAREncoder(layers.Layer, NestedObject):
     """Implements encoder module of the SAR model
 
     Args:
-    ----
         rnn_units: number of hidden rnn units
         dropout_prob: dropout probability
     """
@@ -58,7 +57,6 @@ class AttentionModule(layers.Layer, NestedObject):
     """Implements attention module of the SAR model
 
     Args:
-    ----
         attention_units: number of hidden attention units
 
     """
@@ -120,7 +118,6 @@ class SARDecoder(layers.Layer, NestedObject):
    """Implements decoder module of the SAR model
 
    Args:
-    ----
        rnn_units: number of hidden units in recurrent cells
        max_length: maximum length of a sequence
        vocab_size: number of classes in the model alphabet
@@ -159,13 +156,13 @@ class SARDecoder(layers.Layer, NestedObject):
         self,
         features: tf.Tensor,
         holistic: tf.Tensor,
-        gt:
+        gt: tf.Tensor | None = None,
         **kwargs: Any,
     ) -> tf.Tensor:
         if gt is not None:
             gt_embedding = self.embed_tgt(gt, **kwargs)
 
-        logits_list:
+        logits_list: list[tf.Tensor] = []
 
         for t in range(self.max_length + 1):  # 32
             if t == 0:
@@ -210,7 +207,6 @@ class SAR(Model, RecognitionModel):
     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -223,7 +219,7 @@ class SAR(Model, RecognitionModel):
         cfg: dictionary containing information about the model
     """
 
-    _children_names:
+    _children_names: list[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
 
     def __init__(
         self,
@@ -236,7 +232,7 @@ class SAR(Model, RecognitionModel):
         num_decoder_cells: int = 2,
         dropout_prob: float = 0.0,
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -259,6 +255,15 @@ class SAR(Model, RecognitionModel):
 
         self.postprocessor = SARPostProcessor(vocab=vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     @staticmethod
     def compute_loss(
         model_output: tf.Tensor,
@@ -269,13 +274,11 @@ class SAR(Model, RecognitionModel):
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             gt: the encoded tensor with gt labels
             model_output: predicted logits of the model
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
@@ -296,11 +299,11 @@ class SAR(Model, RecognitionModel):
     def call(
         self,
         x: tf.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)
         # vertical max pooling --> (N, C, W)
         pooled_features = tf.reduce_max(features, axis=1)
@@ -318,7 +321,7 @@ class SAR(Model, RecognitionModel):
             self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
         )
 
-        out:
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -340,14 +343,13 @@ class SARPostProcessor(RecognitionPostProcessor):
     """Post processor for SAR architectures
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         # N x L
@@ -371,7 +373,7 @@ def _sar(
     pretrained: bool,
     backbone_fn,
     pretrained_backbone: bool = True,
-    input_shape:
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> SAR:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -396,9 +398,7 @@ def _sar(
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
@@ -414,12 +414,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the SAR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _sar("sar_resnet31", pretrained, resnet31, **kwargs)
```
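The notable change in this file is that checkpoint loading moves from the `_sar` factory helper onto the model itself through the new `from_pretrained` method. A minimal sketch of how that method could be used on a TensorFlow SAR instance; the checkpoint path below is a placeholder, not a file shipped with the release:

```python
from doctr.models import sar_resnet31

# Build the TensorFlow SAR architecture without fetching the default weights
model = sar_resnet31(pretrained=False)

# As added in the hunk above: weights can now be loaded through the model itself,
# from a local path or a URL (placeholder path below)
model.from_pretrained("path/to/sar_finetuned_checkpoint")
```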
doctr/models/recognition/utils.py

```diff
@@ -1,89 +1,93 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List
 
-from rapidfuzz.distance import
+from rapidfuzz.distance import Hamming
 
 __all__ = ["merge_strings", "merge_multi_strings"]
 
 
-def merge_strings(a: str, b: str,
+def merge_strings(a: str, b: str, overlap_ratio: float) -> str:
     """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
 
     Args:
-    ----
         a: first char seq, suffix should be similar to b's prefix.
         b: second char seq, prefix should be similar to a's suffix.
-
-            only used when the mother sequence is splitted on a character repetition
+        overlap_ratio: estimated ratio of overlapping characters.
 
     Returns:
-    -------
         A merged character sequence.
 
     Example::
-        >>> from doctr.
-        >>>
+        >>> from doctr.models.recognition.utils import merge_strings
+        >>> merge_strings('abcd', 'cdefgh', 0.5)
         'abcdefgh'
-        >>>
+        >>> merge_strings('abcdi', 'cdefgh', 0.5)
         'abcdefgh'
     """
     seq_len = min(len(a), len(b))
-    if seq_len
-        return b if len(a) == 0 else a
-
-    # Initialize merging index and corresponding score (mean Levenstein)
-    min_score, index = 1.0, 0  # No overlap, just concatenate
-
-    scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
-
-    # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
-    if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
-        # Compute n_overlap (number of overlapping chars, geometrically determined)
-        n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
-        # Find the number of consecutive zeros in the scores list
-        # Impossible to have a zero after a non-zero score in that case
-        n_zeros = sum(val == 0 for val in scores)
-        # Index is bounded by the geometrical overlap to avoid collapsing repetitions
-        min_score, index = 0, min(n_zeros, n_overlap)
-
-    else:  # Common case: choose the min score index
-        for i, score in enumerate(scores):
-            if score < min_score:
-                min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char
-
-    # Merge with correct overlap
-    if index == 0:
+    if seq_len <= 1:  # One sequence is empty or will be after cropping in next step, return both to keep data
         return a + b
-    return a[:-1] + b[index - 1 :]
 
+    a_crop, b_crop = a[:-1], b[1:]  # Remove last letter of "a" and first of "b", because they might be cut off
+    max_overlap = min(len(a_crop), len(b_crop))
 
-
-
+    # Compute Hamming distances for all possible overlaps
+    scores = [Hamming.distance(a_crop[-i:], b_crop[:i], processor=None) for i in range(1, max_overlap + 1)]
+
+    # Find zero-score matches
+    zero_matches = [i for i, score in enumerate(scores) if score == 0]
+
+    expected_overlap = round(len(b) * overlap_ratio) - 3  # adjust for cropping and index
+
+    # Case 1: One perfect match - exactly one zero score - just merge there
+    if len(zero_matches) == 1:
+        i = zero_matches[0]
+        return a_crop + b_crop[i + 1 :]
+
+    # Case 2: Multiple perfect matches - likely due to repeated characters.
+    # Use the estimated overlap length to choose the match closest to the expected alignment.
+    elif len(zero_matches) > 1:
+        best_i = min(zero_matches, key=lambda x: abs(x - expected_overlap))
+        return a_crop + b_crop[best_i + 1 :]
+
+    # Case 3: Absence of zero scores indicates that the same character in the image was recognized differently OR that
+    # the overlap was too small and we just need to merge the crops fully
+    if expected_overlap < -1:
+        return a + b
+    elif expected_overlap < 0:
+        return a_crop + b_crop
+
+    # Find best overlap by minimizing Hamming distance + distance from expected overlap size
+    combined_scores = [score + abs(i - expected_overlap) for i, score in enumerate(scores)]
+    best_i = combined_scores.index(min(combined_scores))
+    return a_crop + b_crop[best_i + 1 :]
+
+
+def merge_multi_strings(seq_list: list[str], overlap_ratio: float, last_overlap_ratio: float) -> str:
+    """
+    Merges consecutive string sequences with overlapping characters.
 
     Args:
-    ----
         seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
-
-
+        overlap_ratio: Estimated ratio of overlapping letters between neighboring strings.
+        last_overlap_ratio: Estimated ratio of overlapping letters for the last element in seq_list.
 
     Returns:
-    -------
         A merged character sequence
 
     Example::
-        >>> from doctr.
-        >>>
+        >>> from doctr.models.recognition.utils import merge_multi_strings
+        >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 0.5, 0.1)
         'abcdefghijkl'
     """
-
-
-
-
-
-
-
-    return
+    if not seq_list:
+        return ""
+    result = seq_list[0]
+    for i in range(1, len(seq_list)):
+        text_b = seq_list[i]
+        ratio = last_overlap_ratio if i == len(seq_list) - 1 else overlap_ratio
+        result = merge_strings(result, text_b, ratio)
+    return result
```
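The crop-merging helpers are rewritten here: the old Levenshtein/dilation-factor heuristic is replaced by Hamming distances over cropped sequences plus an expected overlap ratio per neighbour, with a separate ratio for the last crop. A minimal sketch of the new signatures, reusing the values from the docstrings above (expected outputs per those docstrings: 'abcdefgh' and 'abcdefghijkl'):

```python
from doctr.models.recognition.utils import merge_multi_strings, merge_strings

# Two overlapping crops of the same text line, ~50% estimated overlap
print(merge_strings("abcd", "cdefgh", 0.5))

# Several consecutive crops; the final crop gets its own (smaller) overlap ratio
print(merge_multi_strings(["abc", "bcdef", "difghi", "aijkl"], 0.5, 0.1))
```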
doctr/models/recognition/viptr/pytorch.py (new file)

```diff
@@ -0,0 +1,277 @@
+# Copyright (C) 2021-2025, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from collections.abc import Callable
+from copy import deepcopy
+from itertools import groupby
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.models._utils import IntermediateLayerGetter
+
+from doctr.datasets import VOCABS, decode_sequence
+
+from ...classification import vip_tiny
+from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ..core import RecognitionModel, RecognitionPostProcessor
+
+__all__ = ["VIPTR", "viptr_tiny"]
+
+
+default_cfgs: dict[str, dict[str, Any]] = {
+    "viptr_tiny": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["french"],
+        "url": "https://doctr-static.mindee.com/models?id=v0.11.0/viptr_tiny-1cb2515e.pt&src=0",
+    },
+}
+
+
+class VIPTRPostProcessor(RecognitionPostProcessor):
+    """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
+
+    Args:
+        vocab: string containing the ordered sequence of supported characters
+    """
+
+    @staticmethod
+    def ctc_best_path(
+        logits: torch.Tensor,
+        vocab: str = VOCABS["french"],
+        blank: int = 0,
+    ) -> list[tuple[str, float]]:
+        """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
+        <https://github.com/githubharald/CTCDecoder>`_.
+
+        Args:
+            logits: model output, shape: N x T x C
+            vocab: vocabulary to use
+            blank: index of blank label
+
+        Returns:
+            A list of tuples: (word, confidence)
+        """
+        # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
+        probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
+
+        # collapse best path (using itertools.groupby), map to chars, join char list to string
+        words = [
+            decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
+            for seq in torch.argmax(logits, dim=-1)
+        ]
+
+        return list(zip(words, probs.tolist()))
+
+    def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
+        """Performs decoding of raw output with CTC and decoding of CTC predictions
+        with label_to_idx mapping dictionnary
+
+        Args:
+            logits: raw output of the model, shape (N, C + 1, seq_len)
+
+        Returns:
+            A tuple of 2 lists: a list of str (words) and a list of float (probs)
+
+        """
+        # Decode CTC
+        return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
+
+
+class VIPTR(RecognitionModel, nn.Module):
+    """Implements a VIPTR architecture as described in `"A Vision Permutable Extractor for Fast and Efficient
+    Scene Text Recognition" <https://arxiv.org/abs/2401.10110>`_.
+
+    Args:
+        feature_extractor: the backbone serving as feature extractor
+        vocab: vocabulary used for encoding
+        input_shape: input shape of the image
+        exportable: onnx exportable returns only logits
+        cfg: configuration dictionary
+    """
+
+    def __init__(
+        self,
+        feature_extractor: nn.Module,
+        vocab: str,
+        input_shape: tuple[int, int, int] = (3, 32, 128),
+        exportable: bool = False,
+        cfg: dict[str, Any] | None = None,
+    ):
+        super().__init__()
+        self.vocab = vocab
+        self.exportable = exportable
+        self.cfg = cfg
+        self.max_length = 32
+        self.vocab_size = len(vocab)
+
+        self.feat_extractor = feature_extractor
+        with torch.inference_mode():
+            embedding_units = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape[-1]
+
+        self.postprocessor = VIPTRPostProcessor(vocab=self.vocab)
+        self.head = nn.Linear(embedding_units, len(self.vocab) + 1)  # +1 for PAD
+
+        for n, m in self.named_modules():
+            # Don't override the initialization of the backbone
+            if n.startswith("feat_extractor."):
+                continue
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        target: list[str] | None = None,
+        return_model_output: bool = False,
+        return_preds: bool = False,
+    ) -> dict[str, Any]:
+        if target is not None:
+            _gt, _seq_len = self.build_target(target)
+            gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
+            gt, seq_len = gt.to(x.device), seq_len.to(x.device)
+
+        if self.training and target is None:
+            raise ValueError("Need to provide labels during training")
+
+        features = self.feat_extractor(x)["features"]  # (B, max_len, embed_dim)
+        B, N, E = features.size()
+        logits = self.head(features).view(B, N, len(self.vocab) + 1)
+
+        decoded_features = _bf16_to_float32(logits)
+
+        out: dict[str, Any] = {}
+        if self.exportable:
+            out["logits"] = decoded_features
+            return out
+
+        if return_model_output:
+            out["out_map"] = decoded_features
+
+        if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
+            # Post-process boxes
+            out["preds"] = _postprocess(decoded_features)
+
+        if target is not None:
+            out["loss"] = self.compute_loss(decoded_features, gt, seq_len, len(self.vocab))
+
+        return out
+
+    @staticmethod
+    def compute_loss(
+        model_output: torch.Tensor,
+        gt: torch.Tensor,
+        seq_len: torch.Tensor,
+        blank_idx: int = 0,
+    ) -> torch.Tensor:
+        """Compute CTC loss for the model.
+
+        Args:
+            model_output: predicted logits of the model
+            gt: ground truth tensor
+            seq_len: sequence lengths of the ground truth
+            blank_idx: index of the blank label
+
+        Returns:
+            The loss of the model on the batch
+        """
+        batch_len = model_output.shape[0]
+        input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
+        # N x T x C -> T x N x C
+        logits = model_output.permute(1, 0, 2)
+        probs = F.log_softmax(logits, dim=-1)
+        ctc_loss = F.ctc_loss(
+            probs,
+            gt,
+            input_length,
+            seq_len,
+            blank_idx,
+            zero_infinity=True,
+        )
+
+        return ctc_loss
+
+
+def _viptr(
+    arch: str,
+    pretrained: bool,
+    backbone_fn: Callable[[bool], nn.Module],
+    layer: str,
+    pretrained_backbone: bool = True,
+    ignore_keys: list[str] | None = None,
+    **kwargs: Any,
+) -> VIPTR:
+    pretrained_backbone = pretrained_backbone and not pretrained
+
+    # Patch the config
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+
+    # Feature extractor
+    feat_extractor = IntermediateLayerGetter(
+        backbone_fn(pretrained_backbone, input_shape=_cfg["input_shape"]),  # type: ignore[call-arg]
+        {layer: "features"},
+    )
+
+    kwargs["vocab"] = _cfg["vocab"]
+    kwargs["input_shape"] = _cfg["input_shape"]
+
+    model = VIPTR(feat_extractor, cfg=_cfg, **kwargs)
+
+    # Load pretrained parameters
+    if pretrained:
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+    return model
+
+
+def viptr_tiny(pretrained: bool = False, **kwargs: Any) -> VIPTR:
+    """VIPTR-Tiny as described in `"A Vision Permutable Extractor for Fast and Efficient Scene Text Recognition"
+    <https://arxiv.org/abs/2401.10110>`_.
+
+    >>> import torch
+    >>> from doctr.models import viptr_tiny
+    >>> model = viptr_tiny(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 32, 128))
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: keyword arguments of the VIPTR architecture
+
+    Returns:
+        VIPTR: a VIPTR model instance
+    """
+    return _viptr(
+        "viptr_tiny",
+        pretrained,
+        vip_tiny,
+        "5",
+        ignore_keys=["head.weight", "head.bias"],
+        **kwargs,
+    )
```
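`viptr_tiny` is a new PyTorch-only recognition architecture in this release. Beyond the direct-call doctest in the file above, a hedged end-to-end sketch follows; it assumes "viptr_tiny" is accepted as a `reco_arch` by the predictor zoo (doctr/models/recognition/zoo.py is also modified in this diff), and the image path is a placeholder:

```python
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# Assumption: "viptr_tiny" is registered as a recognition architecture in 0.12
predictor = ocr_predictor(det_arch="fast_base", reco_arch="viptr_tiny", pretrained=True)

doc = DocumentFile.from_images(["sample_page.jpg"])  # placeholder image path
result = predictor(doc)
print(result.render())
```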
doctr/models/recognition/vitstr/__init__.py

```diff
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
```
doctr/models/recognition/vitstr/base.py

```diff
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Tuple
 
 import numpy as np
 
@@ -17,17 +16,15 @@ class _ViTSTR:
 
     def build_target(
         self,
-        gts:
-    ) ->
+        gts: list[str],
+    ) -> tuple[np.ndarray, list[int]]:
         """Encode a list of gts sequences into a np array and gives the corresponding*
         sequence lengths.
 
         Args:
-        ----
             gts: list of ground-truth labels
 
         Returns:
-        -------
             A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
         """
         encoded = encode_sequences(
@@ -45,7 +42,6 @@ class _ViTSTRPostProcessor(RecognitionPostProcessor):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
```