python-doctr 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/loader.py +1 -1
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1102 -54
- doctr/file_utils.py +9 -0
- doctr/io/elements.py +37 -3
- doctr/models/_utils.py +1 -1
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +1 -2
- doctr/models/classification/magc_resnet/tensorflow.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/mobilenet/tensorflow.py +11 -2
- doctr/models/classification/predictor/pytorch.py +1 -1
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/resnet/tensorflow.py +25 -4
- doctr/models/classification/textnet/pytorch.py +10 -1
- doctr/models/classification/textnet/tensorflow.py +11 -2
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vgg/tensorflow.py +11 -2
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/pytorch.py +10 -1
- doctr/models/classification/vit/tensorflow.py +9 -0
- doctr/models/classification/zoo.py +4 -0
- doctr/models/detection/differentiable_binarization/base.py +3 -4
- doctr/models/detection/differentiable_binarization/pytorch.py +10 -1
- doctr/models/detection/differentiable_binarization/tensorflow.py +11 -4
- doctr/models/detection/fast/base.py +2 -3
- doctr/models/detection/fast/pytorch.py +13 -4
- doctr/models/detection/fast/tensorflow.py +10 -2
- doctr/models/detection/linknet/base.py +2 -3
- doctr/models/detection/linknet/pytorch.py +10 -1
- doctr/models/detection/linknet/tensorflow.py +10 -2
- doctr/models/factory/hub.py +3 -3
- doctr/models/kie_predictor/pytorch.py +1 -1
- doctr/models/kie_predictor/tensorflow.py +1 -1
- doctr/models/modules/layers/pytorch.py +49 -1
- doctr/models/predictor/pytorch.py +1 -1
- doctr/models/predictor/tensorflow.py +1 -1
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/pytorch.py +10 -1
- doctr/models/recognition/crnn/tensorflow.py +10 -1
- doctr/models/recognition/master/pytorch.py +10 -1
- doctr/models/recognition/master/tensorflow.py +10 -3
- doctr/models/recognition/parseq/pytorch.py +23 -5
- doctr/models/recognition/parseq/tensorflow.py +13 -5
- doctr/models/recognition/predictor/_utils.py +107 -45
- doctr/models/recognition/predictor/pytorch.py +3 -3
- doctr/models/recognition/predictor/tensorflow.py +3 -3
- doctr/models/recognition/sar/pytorch.py +10 -1
- doctr/models/recognition/sar/tensorflow.py +10 -3
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/pytorch.py +10 -1
- doctr/models/recognition/vitstr/tensorflow.py +10 -3
- doctr/models/recognition/zoo.py +5 -0
- doctr/models/utils/pytorch.py +28 -18
- doctr/models/utils/tensorflow.py +15 -8
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +19 -3
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/RECORD +82 -75
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
|
@@ -228,6 +228,15 @@ class SAR(nn.Module, RecognitionModel):
|
|
|
228
228
|
nn.init.constant_(m.weight, 1)
|
|
229
229
|
nn.init.constant_(m.bias, 0)
|
|
230
230
|
|
|
231
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
232
|
+
"""Load pretrained parameters onto the model
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
236
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
237
|
+
"""
|
|
238
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
239
|
+
|
|
231
240
|
def forward(
|
|
232
241
|
self,
|
|
233
242
|
x: torch.Tensor,
|
|
@@ -364,7 +373,7 @@ def _sar(
|
|
|
364
373
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
365
374
|
# remove the last layer weights
|
|
366
375
|
_ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
|
|
367
|
-
|
|
376
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
368
377
|
|
|
369
378
|
return model
|
|
370
379
|
|
|
@@ -255,6 +255,15 @@ class SAR(Model, RecognitionModel):
|
|
|
255
255
|
|
|
256
256
|
self.postprocessor = SARPostProcessor(vocab=vocab)
|
|
257
257
|
|
|
258
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
259
|
+
"""Load pretrained parameters onto the model
|
|
260
|
+
|
|
261
|
+
Args:
|
|
262
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
263
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
264
|
+
"""
|
|
265
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
266
|
+
|
|
258
267
|
@staticmethod
|
|
259
268
|
def compute_loss(
|
|
260
269
|
model_output: tf.Tensor,
|
|
@@ -389,9 +398,7 @@ def _sar(
|
|
|
389
398
|
# Load pretrained parameters
|
|
390
399
|
if pretrained:
|
|
391
400
|
# The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
|
|
392
|
-
|
|
393
|
-
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
|
|
394
|
-
)
|
|
401
|
+
model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
|
|
395
402
|
|
|
396
403
|
return model
|
|
397
404
|
|
|
@@ -4,81 +4,90 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
from rapidfuzz.distance import
|
|
7
|
+
from rapidfuzz.distance import Hamming
|
|
8
8
|
|
|
9
9
|
__all__ = ["merge_strings", "merge_multi_strings"]
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def merge_strings(a: str, b: str,
|
|
12
|
+
def merge_strings(a: str, b: str, overlap_ratio: float) -> str:
|
|
13
13
|
"""Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
|
|
14
14
|
|
|
15
15
|
Args:
|
|
16
16
|
a: first char seq, suffix should be similar to b's prefix.
|
|
17
17
|
b: second char seq, prefix should be similar to a's suffix.
|
|
18
|
-
|
|
19
|
-
only used when the mother sequence is splitted on a character repetition
|
|
18
|
+
overlap_ratio: estimated ratio of overlapping characters.
|
|
20
19
|
|
|
21
20
|
Returns:
|
|
22
21
|
A merged character sequence.
|
|
23
22
|
|
|
24
23
|
Example::
|
|
25
|
-
>>> from doctr.models.recognition.utils import
|
|
26
|
-
>>>
|
|
24
|
+
>>> from doctr.models.recognition.utils import merge_strings
|
|
25
|
+
>>> merge_strings('abcd', 'cdefgh', 0.5)
|
|
27
26
|
'abcdefgh'
|
|
28
|
-
>>>
|
|
27
|
+
>>> merge_strings('abcdi', 'cdefgh', 0.5)
|
|
29
28
|
'abcdefgh'
|
|
30
29
|
"""
|
|
31
30
|
seq_len = min(len(a), len(b))
|
|
32
|
-
if seq_len
|
|
33
|
-
return b if len(a) == 0 else a
|
|
34
|
-
|
|
35
|
-
# Initialize merging index and corresponding score (mean Levenstein)
|
|
36
|
-
min_score, index = 1.0, 0 # No overlap, just concatenate
|
|
37
|
-
|
|
38
|
-
scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
|
|
39
|
-
|
|
40
|
-
# Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
|
|
41
|
-
if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
|
|
42
|
-
# Compute n_overlap (number of overlapping chars, geometrically determined)
|
|
43
|
-
n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
|
|
44
|
-
# Find the number of consecutive zeros in the scores list
|
|
45
|
-
# Impossible to have a zero after a non-zero score in that case
|
|
46
|
-
n_zeros = sum(val == 0 for val in scores)
|
|
47
|
-
# Index is bounded by the geometrical overlap to avoid collapsing repetitions
|
|
48
|
-
min_score, index = 0, min(n_zeros, n_overlap)
|
|
49
|
-
|
|
50
|
-
else: # Common case: choose the min score index
|
|
51
|
-
for i, score in enumerate(scores):
|
|
52
|
-
if score < min_score:
|
|
53
|
-
min_score, index = score, i + 1 # Add one because first index is an overlap of 1 char
|
|
54
|
-
|
|
55
|
-
# Merge with correct overlap
|
|
56
|
-
if index == 0:
|
|
31
|
+
if seq_len <= 1: # One sequence is empty or will be after cropping in next step, return both to keep data
|
|
57
32
|
return a + b
|
|
58
|
-
return a[:-1] + b[index - 1 :]
|
|
59
33
|
|
|
34
|
+
a_crop, b_crop = a[:-1], b[1:] # Remove last letter of "a" and first of "b", because they might be cut off
|
|
35
|
+
max_overlap = min(len(a_crop), len(b_crop))
|
|
60
36
|
|
|
61
|
-
|
|
62
|
-
|
|
37
|
+
# Compute Hamming distances for all possible overlaps
|
|
38
|
+
scores = [Hamming.distance(a_crop[-i:], b_crop[:i], processor=None) for i in range(1, max_overlap + 1)]
|
|
39
|
+
|
|
40
|
+
# Find zero-score matches
|
|
41
|
+
zero_matches = [i for i, score in enumerate(scores) if score == 0]
|
|
42
|
+
|
|
43
|
+
expected_overlap = round(len(b) * overlap_ratio) - 3 # adjust for cropping and index
|
|
44
|
+
|
|
45
|
+
# Case 1: One perfect match - exactly one zero score - just merge there
|
|
46
|
+
if len(zero_matches) == 1:
|
|
47
|
+
i = zero_matches[0]
|
|
48
|
+
return a_crop + b_crop[i + 1 :]
|
|
49
|
+
|
|
50
|
+
# Case 2: Multiple perfect matches - likely due to repeated characters.
|
|
51
|
+
# Use the estimated overlap length to choose the match closest to the expected alignment.
|
|
52
|
+
elif len(zero_matches) > 1:
|
|
53
|
+
best_i = min(zero_matches, key=lambda x: abs(x - expected_overlap))
|
|
54
|
+
return a_crop + b_crop[best_i + 1 :]
|
|
55
|
+
|
|
56
|
+
# Case 3: Absence of zero scores indicates that the same character in the image was recognized differently OR that
|
|
57
|
+
# the overlap was too small and we just need to merge the crops fully
|
|
58
|
+
if expected_overlap < -1:
|
|
59
|
+
return a + b
|
|
60
|
+
elif expected_overlap < 0:
|
|
61
|
+
return a_crop + b_crop
|
|
62
|
+
|
|
63
|
+
# Find best overlap by minimizing Hamming distance + distance from expected overlap size
|
|
64
|
+
combined_scores = [score + abs(i - expected_overlap) for i, score in enumerate(scores)]
|
|
65
|
+
best_i = combined_scores.index(min(combined_scores))
|
|
66
|
+
return a_crop + b_crop[best_i + 1 :]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def merge_multi_strings(seq_list: list[str], overlap_ratio: float, last_overlap_ratio: float) -> str:
|
|
70
|
+
"""
|
|
71
|
+
Merges consecutive string sequences with overlapping characters.
|
|
63
72
|
|
|
64
73
|
Args:
|
|
65
74
|
seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
|
|
66
|
-
|
|
67
|
-
|
|
75
|
+
overlap_ratio: Estimated ratio of overlapping letters between neighboring strings.
|
|
76
|
+
last_overlap_ratio: Estimated ratio of overlapping letters for the last element in seq_list.
|
|
68
77
|
|
|
69
78
|
Returns:
|
|
70
79
|
A merged character sequence
|
|
71
80
|
|
|
72
81
|
Example::
|
|
73
|
-
>>> from doctr.models.recognition.utils import
|
|
74
|
-
>>>
|
|
82
|
+
>>> from doctr.models.recognition.utils import merge_multi_strings
|
|
83
|
+
>>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 0.5, 0.1)
|
|
75
84
|
'abcdefghijkl'
|
|
76
85
|
"""
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
return
|
|
86
|
+
if not seq_list:
|
|
87
|
+
return ""
|
|
88
|
+
result = seq_list[0]
|
|
89
|
+
for i in range(1, len(seq_list)):
|
|
90
|
+
text_b = seq_list[i]
|
|
91
|
+
ratio = last_overlap_ratio if i == len(seq_list) - 1 else overlap_ratio
|
|
92
|
+
result = merge_strings(result, text_b, ratio)
|
|
93
|
+
return result
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from copy import deepcopy
|
|
8
|
+
from itertools import groupby
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
import torch.nn.functional as F
|
|
14
|
+
from torchvision.models._utils import IntermediateLayerGetter
|
|
15
|
+
|
|
16
|
+
from doctr.datasets import VOCABS, decode_sequence
|
|
17
|
+
|
|
18
|
+
from ...classification import vip_tiny
|
|
19
|
+
from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
|
|
20
|
+
from ..core import RecognitionModel, RecognitionPostProcessor
|
|
21
|
+
|
|
22
|
+
__all__ = ["VIPTR", "viptr_tiny"]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
26
|
+
"viptr_tiny": {
|
|
27
|
+
"mean": (0.694, 0.695, 0.693),
|
|
28
|
+
"std": (0.299, 0.296, 0.301),
|
|
29
|
+
"input_shape": (3, 32, 128),
|
|
30
|
+
"vocab": VOCABS["french"],
|
|
31
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.11.0/viptr_tiny-1cb2515e.pt&src=0",
|
|
32
|
+
},
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class VIPTRPostProcessor(RecognitionPostProcessor):
|
|
37
|
+
"""Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
vocab: string containing the ordered sequence of supported characters
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
@staticmethod
|
|
44
|
+
def ctc_best_path(
|
|
45
|
+
logits: torch.Tensor,
|
|
46
|
+
vocab: str = VOCABS["french"],
|
|
47
|
+
blank: int = 0,
|
|
48
|
+
) -> list[tuple[str, float]]:
|
|
49
|
+
"""Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
|
|
50
|
+
<https://github.com/githubharald/CTCDecoder>`_.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
logits: model output, shape: N x T x C
|
|
54
|
+
vocab: vocabulary to use
|
|
55
|
+
blank: index of blank label
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
A list of tuples: (word, confidence)
|
|
59
|
+
"""
|
|
60
|
+
# Gather the most confident characters, and assign the smallest conf among those to the sequence prob
|
|
61
|
+
probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
|
|
62
|
+
|
|
63
|
+
# collapse best path (using itertools.groupby), map to chars, join char list to string
|
|
64
|
+
words = [
|
|
65
|
+
decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
|
|
66
|
+
for seq in torch.argmax(logits, dim=-1)
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
return list(zip(words, probs.tolist()))
|
|
70
|
+
|
|
71
|
+
def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
|
|
72
|
+
"""Performs decoding of raw output with CTC and decoding of CTC predictions
|
|
73
|
+
with label_to_idx mapping dictionnary
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
logits: raw output of the model, shape (N, C + 1, seq_len)
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
A tuple of 2 lists: a list of str (words) and a list of float (probs)
|
|
80
|
+
|
|
81
|
+
"""
|
|
82
|
+
# Decode CTC
|
|
83
|
+
return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class VIPTR(RecognitionModel, nn.Module):
|
|
87
|
+
"""Implements a VIPTR architecture as described in `"A Vision Permutable Extractor for Fast and Efficient
|
|
88
|
+
Scene Text Recognition" <https://arxiv.org/abs/2401.10110>`_.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
feature_extractor: the backbone serving as feature extractor
|
|
92
|
+
vocab: vocabulary used for encoding
|
|
93
|
+
input_shape: input shape of the image
|
|
94
|
+
exportable: onnx exportable returns only logits
|
|
95
|
+
cfg: configuration dictionary
|
|
96
|
+
"""
|
|
97
|
+
|
|
98
|
+
def __init__(
|
|
99
|
+
self,
|
|
100
|
+
feature_extractor: nn.Module,
|
|
101
|
+
vocab: str,
|
|
102
|
+
input_shape: tuple[int, int, int] = (3, 32, 128),
|
|
103
|
+
exportable: bool = False,
|
|
104
|
+
cfg: dict[str, Any] | None = None,
|
|
105
|
+
):
|
|
106
|
+
super().__init__()
|
|
107
|
+
self.vocab = vocab
|
|
108
|
+
self.exportable = exportable
|
|
109
|
+
self.cfg = cfg
|
|
110
|
+
self.max_length = 32
|
|
111
|
+
self.vocab_size = len(vocab)
|
|
112
|
+
|
|
113
|
+
self.feat_extractor = feature_extractor
|
|
114
|
+
with torch.inference_mode():
|
|
115
|
+
embedding_units = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape[-1]
|
|
116
|
+
|
|
117
|
+
self.postprocessor = VIPTRPostProcessor(vocab=self.vocab)
|
|
118
|
+
self.head = nn.Linear(embedding_units, len(self.vocab) + 1) # +1 for PAD
|
|
119
|
+
|
|
120
|
+
for n, m in self.named_modules():
|
|
121
|
+
# Don't override the initialization of the backbone
|
|
122
|
+
if n.startswith("feat_extractor."):
|
|
123
|
+
continue
|
|
124
|
+
if isinstance(m, nn.Linear):
|
|
125
|
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
|
126
|
+
if m.bias is not None:
|
|
127
|
+
nn.init.zeros_(m.bias)
|
|
128
|
+
|
|
129
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
130
|
+
"""Load pretrained parameters onto the model
|
|
131
|
+
|
|
132
|
+
Args:
|
|
133
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
134
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
135
|
+
"""
|
|
136
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
137
|
+
|
|
138
|
+
def forward(
|
|
139
|
+
self,
|
|
140
|
+
x: torch.Tensor,
|
|
141
|
+
target: list[str] | None = None,
|
|
142
|
+
return_model_output: bool = False,
|
|
143
|
+
return_preds: bool = False,
|
|
144
|
+
) -> dict[str, Any]:
|
|
145
|
+
if target is not None:
|
|
146
|
+
_gt, _seq_len = self.build_target(target)
|
|
147
|
+
gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
|
|
148
|
+
gt, seq_len = gt.to(x.device), seq_len.to(x.device)
|
|
149
|
+
|
|
150
|
+
if self.training and target is None:
|
|
151
|
+
raise ValueError("Need to provide labels during training")
|
|
152
|
+
|
|
153
|
+
features = self.feat_extractor(x)["features"] # (B, max_len, embed_dim)
|
|
154
|
+
B, N, E = features.size()
|
|
155
|
+
logits = self.head(features).view(B, N, len(self.vocab) + 1)
|
|
156
|
+
|
|
157
|
+
decoded_features = _bf16_to_float32(logits)
|
|
158
|
+
|
|
159
|
+
out: dict[str, Any] = {}
|
|
160
|
+
if self.exportable:
|
|
161
|
+
out["logits"] = decoded_features
|
|
162
|
+
return out
|
|
163
|
+
|
|
164
|
+
if return_model_output:
|
|
165
|
+
out["out_map"] = decoded_features
|
|
166
|
+
|
|
167
|
+
if target is None or return_preds:
|
|
168
|
+
# Disable for torch.compile compatibility
|
|
169
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
170
|
+
def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
|
|
171
|
+
return self.postprocessor(decoded_features)
|
|
172
|
+
|
|
173
|
+
# Post-process boxes
|
|
174
|
+
out["preds"] = _postprocess(decoded_features)
|
|
175
|
+
|
|
176
|
+
if target is not None:
|
|
177
|
+
out["loss"] = self.compute_loss(decoded_features, gt, seq_len, len(self.vocab))
|
|
178
|
+
|
|
179
|
+
return out
|
|
180
|
+
|
|
181
|
+
@staticmethod
|
|
182
|
+
def compute_loss(
|
|
183
|
+
model_output: torch.Tensor,
|
|
184
|
+
gt: torch.Tensor,
|
|
185
|
+
seq_len: torch.Tensor,
|
|
186
|
+
blank_idx: int = 0,
|
|
187
|
+
) -> torch.Tensor:
|
|
188
|
+
"""Compute CTC loss for the model.
|
|
189
|
+
|
|
190
|
+
Args:
|
|
191
|
+
model_output: predicted logits of the model
|
|
192
|
+
gt: ground truth tensor
|
|
193
|
+
seq_len: sequence lengths of the ground truth
|
|
194
|
+
blank_idx: index of the blank label
|
|
195
|
+
|
|
196
|
+
Returns:
|
|
197
|
+
The loss of the model on the batch
|
|
198
|
+
"""
|
|
199
|
+
batch_len = model_output.shape[0]
|
|
200
|
+
input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
|
|
201
|
+
# N x T x C -> T x N x C
|
|
202
|
+
logits = model_output.permute(1, 0, 2)
|
|
203
|
+
probs = F.log_softmax(logits, dim=-1)
|
|
204
|
+
ctc_loss = F.ctc_loss(
|
|
205
|
+
probs,
|
|
206
|
+
gt,
|
|
207
|
+
input_length,
|
|
208
|
+
seq_len,
|
|
209
|
+
blank_idx,
|
|
210
|
+
zero_infinity=True,
|
|
211
|
+
)
|
|
212
|
+
|
|
213
|
+
return ctc_loss
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def _viptr(
|
|
217
|
+
arch: str,
|
|
218
|
+
pretrained: bool,
|
|
219
|
+
backbone_fn: Callable[[bool], nn.Module],
|
|
220
|
+
layer: str,
|
|
221
|
+
pretrained_backbone: bool = True,
|
|
222
|
+
ignore_keys: list[str] | None = None,
|
|
223
|
+
**kwargs: Any,
|
|
224
|
+
) -> VIPTR:
|
|
225
|
+
pretrained_backbone = pretrained_backbone and not pretrained
|
|
226
|
+
|
|
227
|
+
# Patch the config
|
|
228
|
+
_cfg = deepcopy(default_cfgs[arch])
|
|
229
|
+
_cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
|
|
230
|
+
_cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
|
|
231
|
+
|
|
232
|
+
# Feature extractor
|
|
233
|
+
feat_extractor = IntermediateLayerGetter(
|
|
234
|
+
backbone_fn(pretrained_backbone, input_shape=_cfg["input_shape"]), # type: ignore[call-arg]
|
|
235
|
+
{layer: "features"},
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
kwargs["vocab"] = _cfg["vocab"]
|
|
239
|
+
kwargs["input_shape"] = _cfg["input_shape"]
|
|
240
|
+
|
|
241
|
+
model = VIPTR(feat_extractor, cfg=_cfg, **kwargs)
|
|
242
|
+
|
|
243
|
+
# Load pretrained parameters
|
|
244
|
+
if pretrained:
|
|
245
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
246
|
+
# remove the last layer weights
|
|
247
|
+
_ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
|
|
248
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
249
|
+
|
|
250
|
+
return model
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def viptr_tiny(pretrained: bool = False, **kwargs: Any) -> VIPTR:
|
|
254
|
+
"""VIPTR-Tiny as described in `"A Vision Permutable Extractor for Fast and Efficient Scene Text Recognition"
|
|
255
|
+
<https://arxiv.org/abs/2401.10110>`_.
|
|
256
|
+
|
|
257
|
+
>>> import torch
|
|
258
|
+
>>> from doctr.models import viptr_tiny
|
|
259
|
+
>>> model = viptr_tiny(pretrained=False)
|
|
260
|
+
>>> input_tensor = torch.rand((1, 3, 32, 128))
|
|
261
|
+
>>> out = model(input_tensor)
|
|
262
|
+
|
|
263
|
+
Args:
|
|
264
|
+
pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
|
|
265
|
+
**kwargs: keyword arguments of the VIPTR architecture
|
|
266
|
+
|
|
267
|
+
Returns:
|
|
268
|
+
VIPTR: a VIPTR model instance
|
|
269
|
+
"""
|
|
270
|
+
return _viptr(
|
|
271
|
+
"viptr_tiny",
|
|
272
|
+
pretrained,
|
|
273
|
+
vip_tiny,
|
|
274
|
+
"5",
|
|
275
|
+
ignore_keys=["head.weight", "head.bias"],
|
|
276
|
+
**kwargs,
|
|
277
|
+
)
|
|
@@ -74,6 +74,15 @@ class ViTSTR(_ViTSTR, nn.Module):
|
|
|
74
74
|
|
|
75
75
|
self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
|
|
76
76
|
|
|
77
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
78
|
+
"""Load pretrained parameters onto the model
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
82
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
83
|
+
"""
|
|
84
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
85
|
+
|
|
77
86
|
def forward(
|
|
78
87
|
self,
|
|
79
88
|
x: torch.Tensor,
|
|
@@ -214,7 +223,7 @@ def _vitstr(
|
|
|
214
223
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
215
224
|
# remove the last layer weights
|
|
216
225
|
_ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
|
|
217
|
-
|
|
226
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
218
227
|
|
|
219
228
|
return model
|
|
220
229
|
|
|
@@ -74,6 +74,15 @@ class ViTSTR(_ViTSTR, Model):
|
|
|
74
74
|
|
|
75
75
|
self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
|
|
76
76
|
|
|
77
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
78
|
+
"""Load pretrained parameters onto the model
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
82
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
83
|
+
"""
|
|
84
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
85
|
+
|
|
77
86
|
@staticmethod
|
|
78
87
|
def compute_loss(
|
|
79
88
|
model_output: tf.Tensor,
|
|
@@ -217,9 +226,7 @@ def _vitstr(
|
|
|
217
226
|
# Load pretrained parameters
|
|
218
227
|
if pretrained:
|
|
219
228
|
# The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
|
|
220
|
-
|
|
221
|
-
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
|
|
222
|
-
)
|
|
229
|
+
model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
|
|
223
230
|
|
|
224
231
|
return model
|
|
225
232
|
|
doctr/models/recognition/zoo.py
CHANGED
|
@@ -25,6 +25,9 @@ ARCHS: list[str] = [
|
|
|
25
25
|
"parseq",
|
|
26
26
|
]
|
|
27
27
|
|
|
28
|
+
if is_torch_available():
|
|
29
|
+
ARCHS.extend(["viptr_tiny"])
|
|
30
|
+
|
|
28
31
|
|
|
29
32
|
def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
|
|
30
33
|
if isinstance(arch, str):
|
|
@@ -37,6 +40,8 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
|
|
|
37
40
|
else:
|
|
38
41
|
allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
|
|
39
42
|
if is_torch_available():
|
|
43
|
+
# Add VIPTR which is only available in torch at the moment
|
|
44
|
+
allowed_archs.append(recognition.VIPTR)
|
|
40
45
|
# Adding the type for torch compiled models to the allowed architectures
|
|
41
46
|
from doctr.models.utils import _CompiledModule
|
|
42
47
|
|
doctr/models/utils/pytorch.py
CHANGED
|
@@ -7,6 +7,7 @@ import logging
|
|
|
7
7
|
from typing import Any
|
|
8
8
|
|
|
9
9
|
import torch
|
|
10
|
+
import validators
|
|
10
11
|
from torch import nn
|
|
11
12
|
|
|
12
13
|
from doctr.utils.data import download_from_url
|
|
@@ -36,7 +37,7 @@ def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
|
|
|
36
37
|
|
|
37
38
|
def load_pretrained_params(
|
|
38
39
|
model: nn.Module,
|
|
39
|
-
|
|
40
|
+
path_or_url: str | None = None,
|
|
40
41
|
hash_prefix: str | None = None,
|
|
41
42
|
ignore_keys: list[str] | None = None,
|
|
42
43
|
**kwargs: Any,
|
|
@@ -44,33 +45,42 @@ def load_pretrained_params(
|
|
|
44
45
|
"""Load a set of parameters onto a model
|
|
45
46
|
|
|
46
47
|
>>> from doctr.models import load_pretrained_params
|
|
47
|
-
>>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.
|
|
48
|
+
>>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")
|
|
48
49
|
|
|
49
50
|
Args:
|
|
50
51
|
model: the PyTorch model to be loaded
|
|
51
|
-
|
|
52
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
52
53
|
hash_prefix: first characters of SHA256 expected hash
|
|
53
54
|
ignore_keys: list of weights to be ignored from the state_dict
|
|
54
55
|
**kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
|
|
55
56
|
"""
|
|
56
|
-
if
|
|
57
|
-
logging.warning("
|
|
58
|
-
|
|
59
|
-
|
|
57
|
+
if path_or_url is None:
|
|
58
|
+
logging.warning("No model URL or Path provided, using default initialization.")
|
|
59
|
+
return
|
|
60
|
+
|
|
61
|
+
archive_path = (
|
|
62
|
+
download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
|
|
63
|
+
if validators.url(path_or_url)
|
|
64
|
+
else path_or_url
|
|
65
|
+
)
|
|
60
66
|
|
|
61
|
-
|
|
62
|
-
|
|
67
|
+
# Read state_dict
|
|
68
|
+
state_dict = torch.load(archive_path, map_location="cpu")
|
|
63
69
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
70
|
+
# Remove weights from the state_dict
|
|
71
|
+
if ignore_keys is not None and len(ignore_keys) > 0:
|
|
72
|
+
for key in ignore_keys:
|
|
73
|
+
if key in state_dict:
|
|
67
74
|
state_dict.pop(key)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
75
|
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
|
76
|
+
if any(k not in ignore_keys for k in missing_keys + unexpected_keys):
|
|
77
|
+
raise ValueError(
|
|
78
|
+
"Unable to load state_dict, due to non-matching keys.\n"
|
|
79
|
+
+ f"Unexpected keys: {unexpected_keys}\nMissing keys: {missing_keys}"
|
|
80
|
+
)
|
|
81
|
+
else:
|
|
82
|
+
# Load weights
|
|
83
|
+
model.load_state_dict(state_dict)
|
|
74
84
|
|
|
75
85
|
|
|
76
86
|
def conv_sequence_pt(
|
doctr/models/utils/tensorflow.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Any
|
|
|
9
9
|
|
|
10
10
|
import tensorflow as tf
|
|
11
11
|
import tf2onnx
|
|
12
|
+
import validators
|
|
12
13
|
from tensorflow.keras import Model, layers
|
|
13
14
|
|
|
14
15
|
from doctr.utils.data import download_from_url
|
|
@@ -47,7 +48,7 @@ def _build_model(model: Model):
|
|
|
47
48
|
|
|
48
49
|
def load_pretrained_params(
|
|
49
50
|
model: Model,
|
|
50
|
-
|
|
51
|
+
path_or_url: str | None = None,
|
|
51
52
|
hash_prefix: str | None = None,
|
|
52
53
|
skip_mismatch: bool = False,
|
|
53
54
|
**kwargs: Any,
|
|
@@ -59,17 +60,23 @@ def load_pretrained_params(
|
|
|
59
60
|
|
|
60
61
|
Args:
|
|
61
62
|
model: the keras model to be loaded
|
|
62
|
-
|
|
63
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
63
64
|
hash_prefix: first characters of SHA256 expected hash
|
|
64
65
|
skip_mismatch: skip loading layers with mismatched shapes
|
|
65
66
|
**kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
|
|
66
67
|
"""
|
|
67
|
-
if
|
|
68
|
-
logging.warning("
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
68
|
+
if path_or_url is None:
|
|
69
|
+
logging.warning("No model URL or Path provided, using default initialization.")
|
|
70
|
+
return
|
|
71
|
+
|
|
72
|
+
archive_path = (
|
|
73
|
+
download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
|
|
74
|
+
if validators.url(path_or_url)
|
|
75
|
+
else path_or_url
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
# Load weights
|
|
79
|
+
model.load_weights(archive_path, skip_mismatch=skip_mismatch)
|
|
73
80
|
|
|
74
81
|
|
|
75
82
|
def conv_sequence(
|