python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +1 -5
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1100 -54
- doctr/file_utils.py +2 -92
- doctr/io/elements.py +37 -3
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +4 -4
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -4
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +11 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vip/__init__.py +1 -0
- doctr/models/classification/vip/layers/__init__.py +1 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +12 -3
- doctr/models/classification/zoo.py +7 -8
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +7 -16
- doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +6 -17
- doctr/models/detection/fast/pytorch.py +17 -8
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +5 -15
- doctr/models/detection/linknet/pytorch.py +12 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +9 -22
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +52 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +16 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +15 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +26 -8
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +100 -47
- doctr/models/recognition/predictor/pytorch.py +4 -5
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +13 -4
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +1 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +13 -4
- doctr/models/recognition/zoo.py +13 -8
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +29 -19
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -433
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -397
- doctr/models/classification/textnet/tensorflow.py +0 -266
- doctr/models/classification/vgg/tensorflow.py +0 -116
- doctr/models/classification/vit/tensorflow.py +0 -192
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
- doctr/models/detection/fast/tensorflow.py +0 -419
- doctr/models/detection/linknet/tensorflow.py +0 -369
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -308
- doctr/models/recognition/master/tensorflow.py +0 -313
- doctr/models/recognition/parseq/tensorflow.py +0 -508
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -416
- doctr/models/recognition/vitstr/tensorflow.py +0 -278
- doctr/models/utils/tensorflow.py +0 -182
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.11.0.dist-info/RECORD +0 -173
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
|
@@ -38,13 +38,13 @@ class RecognitionPredictor(nn.Module):
|
|
|
38
38
|
self.model = model.eval()
|
|
39
39
|
self.split_wide_crops = split_wide_crops
|
|
40
40
|
self.critical_ar = 8 # Critical aspect ratio
|
|
41
|
-
self.
|
|
41
|
+
self.overlap_ratio = 0.5 # Ratio of overlap between neighboring crops
|
|
42
42
|
self.target_ar = 6 # Target aspect ratio
|
|
43
43
|
|
|
44
44
|
@torch.inference_mode()
|
|
45
45
|
def forward(
|
|
46
46
|
self,
|
|
47
|
-
crops: Sequence[np.ndarray
|
|
47
|
+
crops: Sequence[np.ndarray],
|
|
48
48
|
**kwargs: Any,
|
|
49
49
|
) -> list[tuple[str, float]]:
|
|
50
50
|
if len(crops) == 0:
|
|
@@ -60,8 +60,7 @@ class RecognitionPredictor(nn.Module):
|
|
|
60
60
|
crops, # type: ignore[arg-type]
|
|
61
61
|
self.critical_ar,
|
|
62
62
|
self.target_ar,
|
|
63
|
-
self.
|
|
64
|
-
isinstance(crops[0], np.ndarray),
|
|
63
|
+
self.overlap_ratio,
|
|
65
64
|
)
|
|
66
65
|
if remapped:
|
|
67
66
|
crops = new_crops
|
|
@@ -81,6 +80,6 @@ class RecognitionPredictor(nn.Module):
|
|
|
81
80
|
|
|
82
81
|
# Remap crops
|
|
83
82
|
if self.split_wide_crops and remapped:
|
|
84
|
-
out = remap_preds(out, crop_map, self.
|
|
83
|
+
out = remap_preds(out, crop_map, self.overlap_ratio)
|
|
85
84
|
|
|
86
85
|
return out
|
|
@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
|
|
|
15
15
|
from doctr.datasets import VOCABS
|
|
16
16
|
|
|
17
17
|
from ...classification import resnet31
|
|
18
|
-
from ...utils
|
|
18
|
+
from ...utils import _bf16_to_float32, load_pretrained_params
|
|
19
19
|
from ..core import RecognitionModel, RecognitionPostProcessor
|
|
20
20
|
|
|
21
21
|
__all__ = ["SAR", "sar_resnet31"]
|
|
@@ -228,6 +228,15 @@ class SAR(nn.Module, RecognitionModel):
|
|
|
228
228
|
nn.init.constant_(m.weight, 1)
|
|
229
229
|
nn.init.constant_(m.bias, 0)
|
|
230
230
|
|
|
231
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
232
|
+
"""Load pretrained parameters onto the model
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
236
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
237
|
+
"""
|
|
238
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
239
|
+
|
|
231
240
|
def forward(
|
|
232
241
|
self,
|
|
233
242
|
x: torch.Tensor,
|
|
@@ -263,7 +272,7 @@ class SAR(nn.Module, RecognitionModel):
|
|
|
263
272
|
|
|
264
273
|
if target is None or return_preds:
|
|
265
274
|
# Disable for torch.compile compatibility
|
|
266
|
-
@torch.compiler.disable
|
|
275
|
+
@torch.compiler.disable
|
|
267
276
|
def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
|
|
268
277
|
return self.postprocessor(decoded_features)
|
|
269
278
|
|
|
@@ -295,7 +304,7 @@ class SAR(nn.Module, RecognitionModel):
|
|
|
295
304
|
# Input length : number of timesteps
|
|
296
305
|
input_len = model_output.shape[1]
|
|
297
306
|
# Add one for additional <eos> token
|
|
298
|
-
seq_len = seq_len + 1
|
|
307
|
+
seq_len = seq_len + 1
|
|
299
308
|
# Compute loss
|
|
300
309
|
# (N, L, vocab_size + 1)
|
|
301
310
|
cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
|
|
@@ -364,7 +373,7 @@ def _sar(
|
|
|
364
373
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
365
374
|
# remove the last layer weights
|
|
366
375
|
_ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
|
|
367
|
-
|
|
376
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
368
377
|
|
|
369
378
|
return model
|
|
370
379
|
|
|
@@ -4,81 +4,90 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
from rapidfuzz.distance import
|
|
7
|
+
from rapidfuzz.distance import Hamming
|
|
8
8
|
|
|
9
9
|
__all__ = ["merge_strings", "merge_multi_strings"]
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def merge_strings(a: str, b: str,
|
|
12
|
+
def merge_strings(a: str, b: str, overlap_ratio: float) -> str:
|
|
13
13
|
"""Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
|
|
14
14
|
|
|
15
15
|
Args:
|
|
16
16
|
a: first char seq, suffix should be similar to b's prefix.
|
|
17
17
|
b: second char seq, prefix should be similar to a's suffix.
|
|
18
|
-
|
|
19
|
-
only used when the mother sequence is splitted on a character repetition
|
|
18
|
+
overlap_ratio: estimated ratio of overlapping characters.
|
|
20
19
|
|
|
21
20
|
Returns:
|
|
22
21
|
A merged character sequence.
|
|
23
22
|
|
|
24
23
|
Example::
|
|
25
|
-
>>> from doctr.models.recognition.utils import
|
|
26
|
-
>>>
|
|
24
|
+
>>> from doctr.models.recognition.utils import merge_strings
|
|
25
|
+
>>> merge_strings('abcd', 'cdefgh', 0.5)
|
|
27
26
|
'abcdefgh'
|
|
28
|
-
>>>
|
|
27
|
+
>>> merge_strings('abcdi', 'cdefgh', 0.5)
|
|
29
28
|
'abcdefgh'
|
|
30
29
|
"""
|
|
31
30
|
seq_len = min(len(a), len(b))
|
|
32
|
-
if seq_len
|
|
33
|
-
return b if len(a) == 0 else a
|
|
34
|
-
|
|
35
|
-
# Initialize merging index and corresponding score (mean Levenstein)
|
|
36
|
-
min_score, index = 1.0, 0 # No overlap, just concatenate
|
|
37
|
-
|
|
38
|
-
scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
|
|
39
|
-
|
|
40
|
-
# Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
|
|
41
|
-
if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
|
|
42
|
-
# Compute n_overlap (number of overlapping chars, geometrically determined)
|
|
43
|
-
n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
|
|
44
|
-
# Find the number of consecutive zeros in the scores list
|
|
45
|
-
# Impossible to have a zero after a non-zero score in that case
|
|
46
|
-
n_zeros = sum(val == 0 for val in scores)
|
|
47
|
-
# Index is bounded by the geometrical overlap to avoid collapsing repetitions
|
|
48
|
-
min_score, index = 0, min(n_zeros, n_overlap)
|
|
49
|
-
|
|
50
|
-
else: # Common case: choose the min score index
|
|
51
|
-
for i, score in enumerate(scores):
|
|
52
|
-
if score < min_score:
|
|
53
|
-
min_score, index = score, i + 1 # Add one because first index is an overlap of 1 char
|
|
54
|
-
|
|
55
|
-
# Merge with correct overlap
|
|
56
|
-
if index == 0:
|
|
31
|
+
if seq_len <= 1: # One sequence is empty or will be after cropping in next step, return both to keep data
|
|
57
32
|
return a + b
|
|
58
|
-
return a[:-1] + b[index - 1 :]
|
|
59
33
|
|
|
34
|
+
a_crop, b_crop = a[:-1], b[1:] # Remove last letter of "a" and first of "b", because they might be cut off
|
|
35
|
+
max_overlap = min(len(a_crop), len(b_crop))
|
|
60
36
|
|
|
61
|
-
|
|
62
|
-
|
|
37
|
+
# Compute Hamming distances for all possible overlaps
|
|
38
|
+
scores = [Hamming.distance(a_crop[-i:], b_crop[:i], processor=None) for i in range(1, max_overlap + 1)]
|
|
39
|
+
|
|
40
|
+
# Find zero-score matches
|
|
41
|
+
zero_matches = [i for i, score in enumerate(scores) if score == 0]
|
|
42
|
+
|
|
43
|
+
expected_overlap = round(len(b) * overlap_ratio) - 3 # adjust for cropping and index
|
|
44
|
+
|
|
45
|
+
# Case 1: One perfect match - exactly one zero score - just merge there
|
|
46
|
+
if len(zero_matches) == 1:
|
|
47
|
+
i = zero_matches[0]
|
|
48
|
+
return a_crop + b_crop[i + 1 :]
|
|
49
|
+
|
|
50
|
+
# Case 2: Multiple perfect matches - likely due to repeated characters.
|
|
51
|
+
# Use the estimated overlap length to choose the match closest to the expected alignment.
|
|
52
|
+
elif len(zero_matches) > 1:
|
|
53
|
+
best_i = min(zero_matches, key=lambda x: abs(x - expected_overlap))
|
|
54
|
+
return a_crop + b_crop[best_i + 1 :]
|
|
55
|
+
|
|
56
|
+
# Case 3: Absence of zero scores indicates that the same character in the image was recognized differently OR that
|
|
57
|
+
# the overlap was too small and we just need to merge the crops fully
|
|
58
|
+
if expected_overlap < -1:
|
|
59
|
+
return a + b
|
|
60
|
+
elif expected_overlap < 0:
|
|
61
|
+
return a_crop + b_crop
|
|
62
|
+
|
|
63
|
+
# Find best overlap by minimizing Hamming distance + distance from expected overlap size
|
|
64
|
+
combined_scores = [score + abs(i - expected_overlap) for i, score in enumerate(scores)]
|
|
65
|
+
best_i = combined_scores.index(min(combined_scores))
|
|
66
|
+
return a_crop + b_crop[best_i + 1 :]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def merge_multi_strings(seq_list: list[str], overlap_ratio: float, last_overlap_ratio: float) -> str:
|
|
70
|
+
"""
|
|
71
|
+
Merges consecutive string sequences with overlapping characters.
|
|
63
72
|
|
|
64
73
|
Args:
|
|
65
74
|
seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
|
|
66
|
-
|
|
67
|
-
|
|
75
|
+
overlap_ratio: Estimated ratio of overlapping letters between neighboring strings.
|
|
76
|
+
last_overlap_ratio: Estimated ratio of overlapping letters for the last element in seq_list.
|
|
68
77
|
|
|
69
78
|
Returns:
|
|
70
79
|
A merged character sequence
|
|
71
80
|
|
|
72
81
|
Example::
|
|
73
|
-
>>> from doctr.models.recognition.utils import
|
|
74
|
-
>>>
|
|
82
|
+
>>> from doctr.models.recognition.utils import merge_multi_strings
|
|
83
|
+
>>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 0.5, 0.1)
|
|
75
84
|
'abcdefghijkl'
|
|
76
85
|
"""
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
return
|
|
86
|
+
if not seq_list:
|
|
87
|
+
return ""
|
|
88
|
+
result = seq_list[0]
|
|
89
|
+
for i in range(1, len(seq_list)):
|
|
90
|
+
text_b = seq_list[i]
|
|
91
|
+
ratio = last_overlap_ratio if i == len(seq_list) - 1 else overlap_ratio
|
|
92
|
+
result = merge_strings(result, text_b, ratio)
|
|
93
|
+
return result
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .pytorch import *
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from copy import deepcopy
|
|
8
|
+
from itertools import groupby
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
import torch.nn.functional as F
|
|
14
|
+
from torchvision.models._utils import IntermediateLayerGetter
|
|
15
|
+
|
|
16
|
+
from doctr.datasets import VOCABS, decode_sequence
|
|
17
|
+
|
|
18
|
+
from ...classification import vip_tiny
|
|
19
|
+
from ...utils import _bf16_to_float32, load_pretrained_params
|
|
20
|
+
from ..core import RecognitionModel, RecognitionPostProcessor
|
|
21
|
+
|
|
22
|
+
__all__ = ["VIPTR", "viptr_tiny"]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
26
|
+
"viptr_tiny": {
|
|
27
|
+
"mean": (0.694, 0.695, 0.693),
|
|
28
|
+
"std": (0.299, 0.296, 0.301),
|
|
29
|
+
"input_shape": (3, 32, 128),
|
|
30
|
+
"vocab": VOCABS["french"],
|
|
31
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.11.0/viptr_tiny-1cb2515e.pt&src=0",
|
|
32
|
+
},
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class VIPTRPostProcessor(RecognitionPostProcessor):
|
|
37
|
+
"""Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
vocab: string containing the ordered sequence of supported characters
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
@staticmethod
|
|
44
|
+
def ctc_best_path(
|
|
45
|
+
logits: torch.Tensor,
|
|
46
|
+
vocab: str = VOCABS["french"],
|
|
47
|
+
blank: int = 0,
|
|
48
|
+
) -> list[tuple[str, float]]:
|
|
49
|
+
"""Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
|
|
50
|
+
<https://github.com/githubharald/CTCDecoder>`_.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
logits: model output, shape: N x T x C
|
|
54
|
+
vocab: vocabulary to use
|
|
55
|
+
blank: index of blank label
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
A list of tuples: (word, confidence)
|
|
59
|
+
"""
|
|
60
|
+
# Gather the most confident characters, and assign the smallest conf among those to the sequence prob
|
|
61
|
+
probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
|
|
62
|
+
|
|
63
|
+
# collapse best path (using itertools.groupby), map to chars, join char list to string
|
|
64
|
+
words = [
|
|
65
|
+
decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
|
|
66
|
+
for seq in torch.argmax(logits, dim=-1)
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
return list(zip(words, probs.tolist()))
|
|
70
|
+
|
|
71
|
+
def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
|
|
72
|
+
"""Performs decoding of raw output with CTC and decoding of CTC predictions
|
|
73
|
+
with label_to_idx mapping dictionary
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
logits: raw output of the model, shape (N, C + 1, seq_len)
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
A tuple of 2 lists: a list of str (words) and a list of float (probs)
|
|
80
|
+
|
|
81
|
+
"""
|
|
82
|
+
# Decode CTC
|
|
83
|
+
return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class VIPTR(RecognitionModel, nn.Module):
|
|
87
|
+
"""Implements a VIPTR architecture as described in `"A Vision Permutable Extractor for Fast and Efficient
|
|
88
|
+
Scene Text Recognition" <https://arxiv.org/abs/2401.10110>`_.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
feature_extractor: the backbone serving as feature extractor
|
|
92
|
+
vocab: vocabulary used for encoding
|
|
93
|
+
input_shape: input shape of the image
|
|
94
|
+
exportable: onnx exportable returns only logits
|
|
95
|
+
cfg: configuration dictionary
|
|
96
|
+
"""
|
|
97
|
+
|
|
98
|
+
def __init__(
|
|
99
|
+
self,
|
|
100
|
+
feature_extractor: nn.Module,
|
|
101
|
+
vocab: str,
|
|
102
|
+
input_shape: tuple[int, int, int] = (3, 32, 128),
|
|
103
|
+
exportable: bool = False,
|
|
104
|
+
cfg: dict[str, Any] | None = None,
|
|
105
|
+
):
|
|
106
|
+
super().__init__()
|
|
107
|
+
self.vocab = vocab
|
|
108
|
+
self.exportable = exportable
|
|
109
|
+
self.cfg = cfg
|
|
110
|
+
self.max_length = 32
|
|
111
|
+
self.vocab_size = len(vocab)
|
|
112
|
+
|
|
113
|
+
self.feat_extractor = feature_extractor
|
|
114
|
+
with torch.inference_mode():
|
|
115
|
+
embedding_units = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape[-1]
|
|
116
|
+
|
|
117
|
+
self.postprocessor = VIPTRPostProcessor(vocab=self.vocab)
|
|
118
|
+
self.head = nn.Linear(embedding_units, len(self.vocab) + 1) # +1 for PAD
|
|
119
|
+
|
|
120
|
+
for n, m in self.named_modules():
|
|
121
|
+
# Don't override the initialization of the backbone
|
|
122
|
+
if n.startswith("feat_extractor."):
|
|
123
|
+
continue
|
|
124
|
+
if isinstance(m, nn.Linear):
|
|
125
|
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
|
126
|
+
if m.bias is not None:
|
|
127
|
+
nn.init.zeros_(m.bias)
|
|
128
|
+
|
|
129
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
130
|
+
"""Load pretrained parameters onto the model
|
|
131
|
+
|
|
132
|
+
Args:
|
|
133
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
134
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
135
|
+
"""
|
|
136
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
137
|
+
|
|
138
|
+
def forward(
|
|
139
|
+
self,
|
|
140
|
+
x: torch.Tensor,
|
|
141
|
+
target: list[str] | None = None,
|
|
142
|
+
return_model_output: bool = False,
|
|
143
|
+
return_preds: bool = False,
|
|
144
|
+
) -> dict[str, Any]:
|
|
145
|
+
if target is not None:
|
|
146
|
+
_gt, _seq_len = self.build_target(target)
|
|
147
|
+
gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
|
|
148
|
+
gt, seq_len = gt.to(x.device), seq_len.to(x.device)
|
|
149
|
+
|
|
150
|
+
if self.training and target is None:
|
|
151
|
+
raise ValueError("Need to provide labels during training")
|
|
152
|
+
|
|
153
|
+
features = self.feat_extractor(x)["features"] # (B, max_len, embed_dim)
|
|
154
|
+
B, N, E = features.size()
|
|
155
|
+
logits = self.head(features).view(B, N, len(self.vocab) + 1)
|
|
156
|
+
|
|
157
|
+
decoded_features = _bf16_to_float32(logits)
|
|
158
|
+
|
|
159
|
+
out: dict[str, Any] = {}
|
|
160
|
+
if self.exportable:
|
|
161
|
+
out["logits"] = decoded_features
|
|
162
|
+
return out
|
|
163
|
+
|
|
164
|
+
if return_model_output:
|
|
165
|
+
out["out_map"] = decoded_features
|
|
166
|
+
|
|
167
|
+
if target is None or return_preds:
|
|
168
|
+
# Disable for torch.compile compatibility
|
|
169
|
+
@torch.compiler.disable
|
|
170
|
+
def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
|
|
171
|
+
return self.postprocessor(decoded_features)
|
|
172
|
+
|
|
173
|
+
# Post-process boxes
|
|
174
|
+
out["preds"] = _postprocess(decoded_features)
|
|
175
|
+
|
|
176
|
+
if target is not None:
|
|
177
|
+
out["loss"] = self.compute_loss(decoded_features, gt, seq_len, len(self.vocab))
|
|
178
|
+
|
|
179
|
+
return out
|
|
180
|
+
|
|
181
|
+
@staticmethod
|
|
182
|
+
def compute_loss(
|
|
183
|
+
model_output: torch.Tensor,
|
|
184
|
+
gt: torch.Tensor,
|
|
185
|
+
seq_len: torch.Tensor,
|
|
186
|
+
blank_idx: int = 0,
|
|
187
|
+
) -> torch.Tensor:
|
|
188
|
+
"""Compute CTC loss for the model.
|
|
189
|
+
|
|
190
|
+
Args:
|
|
191
|
+
model_output: predicted logits of the model
|
|
192
|
+
gt: ground truth tensor
|
|
193
|
+
seq_len: sequence lengths of the ground truth
|
|
194
|
+
blank_idx: index of the blank label
|
|
195
|
+
|
|
196
|
+
Returns:
|
|
197
|
+
The loss of the model on the batch
|
|
198
|
+
"""
|
|
199
|
+
batch_len = model_output.shape[0]
|
|
200
|
+
input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
|
|
201
|
+
# N x T x C -> T x N x C
|
|
202
|
+
logits = model_output.permute(1, 0, 2)
|
|
203
|
+
probs = F.log_softmax(logits, dim=-1)
|
|
204
|
+
ctc_loss = F.ctc_loss(
|
|
205
|
+
probs,
|
|
206
|
+
gt,
|
|
207
|
+
input_length,
|
|
208
|
+
seq_len,
|
|
209
|
+
blank_idx,
|
|
210
|
+
zero_infinity=True,
|
|
211
|
+
)
|
|
212
|
+
|
|
213
|
+
return ctc_loss
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def _viptr(
|
|
217
|
+
arch: str,
|
|
218
|
+
pretrained: bool,
|
|
219
|
+
backbone_fn: Callable[[bool], nn.Module],
|
|
220
|
+
layer: str,
|
|
221
|
+
pretrained_backbone: bool = True,
|
|
222
|
+
ignore_keys: list[str] | None = None,
|
|
223
|
+
**kwargs: Any,
|
|
224
|
+
) -> VIPTR:
|
|
225
|
+
pretrained_backbone = pretrained_backbone and not pretrained
|
|
226
|
+
|
|
227
|
+
# Patch the config
|
|
228
|
+
_cfg = deepcopy(default_cfgs[arch])
|
|
229
|
+
_cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
|
|
230
|
+
_cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
|
|
231
|
+
|
|
232
|
+
# Feature extractor
|
|
233
|
+
feat_extractor = IntermediateLayerGetter(
|
|
234
|
+
backbone_fn(pretrained_backbone, input_shape=_cfg["input_shape"]), # type: ignore[call-arg]
|
|
235
|
+
{layer: "features"},
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
kwargs["vocab"] = _cfg["vocab"]
|
|
239
|
+
kwargs["input_shape"] = _cfg["input_shape"]
|
|
240
|
+
|
|
241
|
+
model = VIPTR(feat_extractor, cfg=_cfg, **kwargs)
|
|
242
|
+
|
|
243
|
+
# Load pretrained parameters
|
|
244
|
+
if pretrained:
|
|
245
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
246
|
+
# remove the last layer weights
|
|
247
|
+
_ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
|
|
248
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
249
|
+
|
|
250
|
+
return model
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def viptr_tiny(pretrained: bool = False, **kwargs: Any) -> VIPTR:
|
|
254
|
+
"""VIPTR-Tiny as described in `"A Vision Permutable Extractor for Fast and Efficient Scene Text Recognition"
|
|
255
|
+
<https://arxiv.org/abs/2401.10110>`_.
|
|
256
|
+
|
|
257
|
+
>>> import torch
|
|
258
|
+
>>> from doctr.models import viptr_tiny
|
|
259
|
+
>>> model = viptr_tiny(pretrained=False)
|
|
260
|
+
>>> input_tensor = torch.rand((1, 3, 32, 128))
|
|
261
|
+
>>> out = model(input_tensor)
|
|
262
|
+
|
|
263
|
+
Args:
|
|
264
|
+
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
265
|
+
**kwargs: keyword arguments of the VIPTR architecture
|
|
266
|
+
|
|
267
|
+
Returns:
|
|
268
|
+
VIPTR: a VIPTR model instance
|
|
269
|
+
"""
|
|
270
|
+
return _viptr(
|
|
271
|
+
"viptr_tiny",
|
|
272
|
+
pretrained,
|
|
273
|
+
vip_tiny,
|
|
274
|
+
"5",
|
|
275
|
+
ignore_keys=["head.weight", "head.bias"],
|
|
276
|
+
**kwargs,
|
|
277
|
+
)
|
|
@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
|
|
|
15
15
|
from doctr.datasets import VOCABS
|
|
16
16
|
|
|
17
17
|
from ...classification import vit_b, vit_s
|
|
18
|
-
from ...utils
|
|
18
|
+
from ...utils import _bf16_to_float32, load_pretrained_params
|
|
19
19
|
from .base import _ViTSTR, _ViTSTRPostProcessor
|
|
20
20
|
|
|
21
21
|
__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
|
|
@@ -74,6 +74,15 @@ class ViTSTR(_ViTSTR, nn.Module):
|
|
|
74
74
|
|
|
75
75
|
self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
|
|
76
76
|
|
|
77
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
78
|
+
"""Load pretrained parameters onto the model
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
82
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
83
|
+
"""
|
|
84
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
85
|
+
|
|
77
86
|
def forward(
|
|
78
87
|
self,
|
|
79
88
|
x: torch.Tensor,
|
|
@@ -108,7 +117,7 @@ class ViTSTR(_ViTSTR, nn.Module):
|
|
|
108
117
|
|
|
109
118
|
if target is None or return_preds:
|
|
110
119
|
# Disable for torch.compile compatibility
|
|
111
|
-
@torch.compiler.disable
|
|
120
|
+
@torch.compiler.disable
|
|
112
121
|
def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
|
|
113
122
|
return self.postprocessor(decoded_features)
|
|
114
123
|
|
|
@@ -140,7 +149,7 @@ class ViTSTR(_ViTSTR, nn.Module):
|
|
|
140
149
|
# Input length : number of steps
|
|
141
150
|
input_len = model_output.shape[1]
|
|
142
151
|
# Add one for additional <eos> token (sos disappear in shift!)
|
|
143
|
-
seq_len = seq_len + 1
|
|
152
|
+
seq_len = seq_len + 1
|
|
144
153
|
# Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
|
|
145
154
|
# The "masked" first gt char is <sos>.
|
|
146
155
|
cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
|
|
@@ -214,7 +223,7 @@ def _vitstr(
|
|
|
214
223
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
215
224
|
# remove the last layer weights
|
|
216
225
|
_ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
|
|
217
|
-
|
|
226
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
218
227
|
|
|
219
228
|
return model
|
|
220
229
|
|
doctr/models/recognition/zoo.py
CHANGED
|
@@ -5,8 +5,8 @@
|
|
|
5
5
|
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
|
-
from doctr.file_utils import is_tf_available, is_torch_available
|
|
9
8
|
from doctr.models.preprocessor import PreProcessor
|
|
9
|
+
from doctr.models.utils import _CompiledModule
|
|
10
10
|
|
|
11
11
|
from .. import recognition
|
|
12
12
|
from .predictor import RecognitionPredictor
|
|
@@ -23,6 +23,7 @@ ARCHS: list[str] = [
|
|
|
23
23
|
"vitstr_small",
|
|
24
24
|
"vitstr_base",
|
|
25
25
|
"parseq",
|
|
26
|
+
"viptr_tiny",
|
|
26
27
|
]
|
|
27
28
|
|
|
28
29
|
|
|
@@ -35,12 +36,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
|
|
|
35
36
|
pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
|
|
36
37
|
)
|
|
37
38
|
else:
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
39
|
+
# Adding the type for torch compiled models to the allowed architectures
|
|
40
|
+
allowed_archs = [
|
|
41
|
+
recognition.CRNN,
|
|
42
|
+
recognition.SAR,
|
|
43
|
+
recognition.MASTER,
|
|
44
|
+
recognition.ViTSTR,
|
|
45
|
+
recognition.PARSeq,
|
|
46
|
+
recognition.VIPTR,
|
|
47
|
+
_CompiledModule,
|
|
48
|
+
]
|
|
44
49
|
|
|
45
50
|
if not isinstance(arch, tuple(allowed_archs)):
|
|
46
51
|
raise ValueError(f"unknown architecture: {type(arch)}")
|
|
@@ -51,7 +56,7 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
|
|
|
51
56
|
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
|
|
52
57
|
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
|
|
53
58
|
kwargs["batch_size"] = kwargs.get("batch_size", 128)
|
|
54
|
-
input_shape = _model.cfg["input_shape"][
|
|
59
|
+
input_shape = _model.cfg["input_shape"][-2:]
|
|
55
60
|
predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)
|
|
56
61
|
|
|
57
62
|
return predictor
|