python-doctr 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. doctr/datasets/__init__.py +1 -0
  2. doctr/datasets/coco_text.py +139 -0
  3. doctr/datasets/cord.py +2 -1
  4. doctr/datasets/funsd.py +2 -2
  5. doctr/datasets/ic03.py +1 -1
  6. doctr/datasets/ic13.py +2 -1
  7. doctr/datasets/iiit5k.py +4 -1
  8. doctr/datasets/imgur5k.py +9 -2
  9. doctr/datasets/loader.py +1 -1
  10. doctr/datasets/ocr.py +1 -1
  11. doctr/datasets/recognition.py +1 -1
  12. doctr/datasets/svhn.py +1 -1
  13. doctr/datasets/svt.py +2 -2
  14. doctr/datasets/synthtext.py +15 -2
  15. doctr/datasets/utils.py +7 -6
  16. doctr/datasets/vocabs.py +1102 -54
  17. doctr/file_utils.py +9 -0
  18. doctr/io/elements.py +37 -3
  19. doctr/models/_utils.py +1 -1
  20. doctr/models/classification/__init__.py +1 -0
  21. doctr/models/classification/magc_resnet/pytorch.py +1 -2
  22. doctr/models/classification/magc_resnet/tensorflow.py +3 -3
  23. doctr/models/classification/mobilenet/pytorch.py +15 -1
  24. doctr/models/classification/mobilenet/tensorflow.py +11 -2
  25. doctr/models/classification/predictor/pytorch.py +1 -1
  26. doctr/models/classification/resnet/pytorch.py +26 -3
  27. doctr/models/classification/resnet/tensorflow.py +25 -4
  28. doctr/models/classification/textnet/pytorch.py +10 -1
  29. doctr/models/classification/textnet/tensorflow.py +11 -2
  30. doctr/models/classification/vgg/pytorch.py +16 -1
  31. doctr/models/classification/vgg/tensorflow.py +11 -2
  32. doctr/models/classification/vip/__init__.py +4 -0
  33. doctr/models/classification/vip/layers/__init__.py +4 -0
  34. doctr/models/classification/vip/layers/pytorch.py +615 -0
  35. doctr/models/classification/vip/pytorch.py +505 -0
  36. doctr/models/classification/vit/pytorch.py +10 -1
  37. doctr/models/classification/vit/tensorflow.py +9 -0
  38. doctr/models/classification/zoo.py +4 -0
  39. doctr/models/detection/differentiable_binarization/base.py +3 -4
  40. doctr/models/detection/differentiable_binarization/pytorch.py +10 -1
  41. doctr/models/detection/differentiable_binarization/tensorflow.py +11 -4
  42. doctr/models/detection/fast/base.py +2 -3
  43. doctr/models/detection/fast/pytorch.py +13 -4
  44. doctr/models/detection/fast/tensorflow.py +10 -2
  45. doctr/models/detection/linknet/base.py +2 -3
  46. doctr/models/detection/linknet/pytorch.py +10 -1
  47. doctr/models/detection/linknet/tensorflow.py +10 -2
  48. doctr/models/factory/hub.py +3 -3
  49. doctr/models/kie_predictor/pytorch.py +1 -1
  50. doctr/models/kie_predictor/tensorflow.py +1 -1
  51. doctr/models/modules/layers/pytorch.py +49 -1
  52. doctr/models/predictor/pytorch.py +1 -1
  53. doctr/models/predictor/tensorflow.py +1 -1
  54. doctr/models/recognition/__init__.py +1 -0
  55. doctr/models/recognition/crnn/pytorch.py +10 -1
  56. doctr/models/recognition/crnn/tensorflow.py +10 -1
  57. doctr/models/recognition/master/pytorch.py +10 -1
  58. doctr/models/recognition/master/tensorflow.py +10 -3
  59. doctr/models/recognition/parseq/pytorch.py +23 -5
  60. doctr/models/recognition/parseq/tensorflow.py +13 -5
  61. doctr/models/recognition/predictor/_utils.py +107 -45
  62. doctr/models/recognition/predictor/pytorch.py +3 -3
  63. doctr/models/recognition/predictor/tensorflow.py +3 -3
  64. doctr/models/recognition/sar/pytorch.py +10 -1
  65. doctr/models/recognition/sar/tensorflow.py +10 -3
  66. doctr/models/recognition/utils.py +56 -47
  67. doctr/models/recognition/viptr/__init__.py +4 -0
  68. doctr/models/recognition/viptr/pytorch.py +277 -0
  69. doctr/models/recognition/vitstr/pytorch.py +10 -1
  70. doctr/models/recognition/vitstr/tensorflow.py +10 -3
  71. doctr/models/recognition/zoo.py +5 -0
  72. doctr/models/utils/pytorch.py +28 -18
  73. doctr/models/utils/tensorflow.py +15 -8
  74. doctr/utils/data.py +1 -1
  75. doctr/utils/geometry.py +1 -1
  76. doctr/version.py +1 -1
  77. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +19 -3
  78. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/RECORD +82 -75
  79. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  80. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  81. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  82. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/sar/pytorch.py

@@ -228,6 +228,15 @@ class SAR(nn.Module, RecognitionModel):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -364,7 +373,7 @@ def _sar(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
doctr/models/recognition/sar/tensorflow.py

@@ -255,6 +255,15 @@ class SAR(Model, RecognitionModel):
 
         self.postprocessor = SARPostProcessor(vocab=vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     @staticmethod
     def compute_loss(
         model_output: tf.Tensor,
@@ -389,9 +398,7 @@ def _sar(
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
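Note: the four SAR hunks above show a pattern repeated across every detection and recognition architecture in 0.12.0: each model gains a `from_pretrained` method wrapping `load_pretrained_params`, and the factory helpers now route through it. A minimal sketch of the resulting workflow, assuming the PyTorch backend (the checkpoint filename is a placeholder, not a file shipped with the package):

    from doctr.models import sar_resnet31

    # Build the architecture without fetching the default weights
    model = sar_resnet31(pretrained=False)

    # Load weights from a local file or a URL; "./my_sar_checkpoint.pt" is illustrative
    model.from_pretrained("./my_sar_checkpoint.pt")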
doctr/models/recognition/utils.py

@@ -4,81 +4,90 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
-from rapidfuzz.distance import Levenshtein
+from rapidfuzz.distance import Hamming
 
 __all__ = ["merge_strings", "merge_multi_strings"]
 
 
-def merge_strings(a: str, b: str, dil_factor: float) -> str:
+def merge_strings(a: str, b: str, overlap_ratio: float) -> str:
     """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
 
     Args:
         a: first char seq, suffix should be similar to b's prefix.
         b: second char seq, prefix should be similar to a's suffix.
-        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
-            only used when the mother sequence is splitted on a character repetition
+        overlap_ratio: estimated ratio of overlapping characters.
 
     Returns:
         A merged character sequence.
 
     Example::
-        >>> from doctr.models.recognition.utils import merge_sequences
-        >>> merge_sequences('abcd', 'cdefgh', 1.4)
+        >>> from doctr.models.recognition.utils import merge_strings
+        >>> merge_strings('abcd', 'cdefgh', 0.5)
         'abcdefgh'
-        >>> merge_sequences('abcdi', 'cdefgh', 1.4)
+        >>> merge_strings('abcdi', 'cdefgh', 0.5)
         'abcdefgh'
     """
     seq_len = min(len(a), len(b))
-    if seq_len == 0:  # One sequence is empty, return the other
-        return b if len(a) == 0 else a
-
-    # Initialize merging index and corresponding score (mean Levenstein)
-    min_score, index = 1.0, 0  # No overlap, just concatenate
-
-    scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
-
-    # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
-    if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
-        # Compute n_overlap (number of overlapping chars, geometrically determined)
-        n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
-        # Find the number of consecutive zeros in the scores list
-        # Impossible to have a zero after a non-zero score in that case
-        n_zeros = sum(val == 0 for val in scores)
-        # Index is bounded by the geometrical overlap to avoid collapsing repetitions
-        min_score, index = 0, min(n_zeros, n_overlap)
-
-    else:  # Common case: choose the min score index
-        for i, score in enumerate(scores):
-            if score < min_score:
-                min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char
-
-    # Merge with correct overlap
-    if index == 0:
+    if seq_len <= 1:  # One sequence is empty or will be after cropping in next step, return both to keep data
         return a + b
-    return a[:-1] + b[index - 1 :]
 
+    a_crop, b_crop = a[:-1], b[1:]  # Remove last letter of "a" and first of "b", because they might be cut off
+    max_overlap = min(len(a_crop), len(b_crop))
 
-def merge_multi_strings(seq_list: list[str], dil_factor: float) -> str:
-    """Recursively merges consecutive string sequences with overlapping characters.
+    # Compute Hamming distances for all possible overlaps
+    scores = [Hamming.distance(a_crop[-i:], b_crop[:i], processor=None) for i in range(1, max_overlap + 1)]
+
+    # Find zero-score matches
+    zero_matches = [i for i, score in enumerate(scores) if score == 0]
+
+    expected_overlap = round(len(b) * overlap_ratio) - 3  # adjust for cropping and index
+
+    # Case 1: One perfect match - exactly one zero score - just merge there
+    if len(zero_matches) == 1:
+        i = zero_matches[0]
+        return a_crop + b_crop[i + 1 :]
+
+    # Case 2: Multiple perfect matches - likely due to repeated characters.
+    # Use the estimated overlap length to choose the match closest to the expected alignment.
+    elif len(zero_matches) > 1:
+        best_i = min(zero_matches, key=lambda x: abs(x - expected_overlap))
+        return a_crop + b_crop[best_i + 1 :]
+
+    # Case 3: Absence of zero scores indicates that the same character in the image was recognized differently OR that
+    # the overlap was too small and we just need to merge the crops fully
+    if expected_overlap < -1:
+        return a + b
+    elif expected_overlap < 0:
+        return a_crop + b_crop
+
+    # Find best overlap by minimizing Hamming distance + distance from expected overlap size
+    combined_scores = [score + abs(i - expected_overlap) for i, score in enumerate(scores)]
+    best_i = combined_scores.index(min(combined_scores))
+    return a_crop + b_crop[best_i + 1 :]
+
+
+def merge_multi_strings(seq_list: list[str], overlap_ratio: float, last_overlap_ratio: float) -> str:
+    """
+    Merges consecutive string sequences with overlapping characters.
 
     Args:
         seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
-        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
-            only used when the mother sequence is splitted on a character repetition
+        overlap_ratio: Estimated ratio of overlapping letters between neighboring strings.
+        last_overlap_ratio: Estimated ratio of overlapping letters for the last element in seq_list.
 
     Returns:
         A merged character sequence
 
     Example::
-        >>> from doctr.models.recognition.utils import merge_multi_sequences
-        >>> merge_multi_sequences(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
+        >>> from doctr.models.recognition.utils import merge_multi_strings
+        >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 0.5, 0.1)
         'abcdefghijkl'
     """
-
-    def _recursive_merge(a: str, seq_list: list[str], dil_factor: float) -> str:
-        # Recursive version of compute_overlap
-        if len(seq_list) == 1:
-            return merge_strings(a, seq_list[0], dil_factor)
-        return _recursive_merge(merge_strings(a, seq_list[0], dil_factor), seq_list[1:], dil_factor)
-
-    return _recursive_merge("", seq_list, dil_factor)
+    if not seq_list:
+        return ""
+    result = seq_list[0]
+    for i in range(1, len(seq_list)):
+        text_b = seq_list[i]
+        ratio = last_overlap_ratio if i == len(seq_list) - 1 else overlap_ratio
+        result = merge_strings(result, text_b, ratio)
+    return result
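Note: the rewrite above drops the Levenshtein/dilation-factor heuristic in favor of a Hamming-distance scan over candidate overlaps, anchored by an estimated overlap length. As a rough walk-through (ours, not part of the diff): for merge_strings('abcd', 'cdefgh', 0.5), the crops are 'abc' and 'defgh' and expected_overlap = round(6 * 0.5) - 3 = 0, so among the candidate overlaps the one minimizing Hamming distance plus deviation from that estimate is kept. The new signatures can be exercised directly:

    from doctr.models.recognition.utils import merge_strings, merge_multi_strings

    # Two crops of the same text line, with an estimated 50% character overlap
    print(merge_strings("abcd", "cdefgh", 0.5))

    # Consecutive crops; the last one usually overlaps less, hence its own ratio
    print(merge_multi_strings(["abc", "bcdef", "difghi", "aijkl"], 0.5, 0.1))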
doctr/models/recognition/viptr/__init__.py

@@ -0,0 +1,4 @@
+from doctr.file_utils import is_torch_available
+
+if is_torch_available():
+    from .pytorch import *
doctr/models/recognition/viptr/pytorch.py

@@ -0,0 +1,277 @@
+# Copyright (C) 2021-2025, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from collections.abc import Callable
+from copy import deepcopy
+from itertools import groupby
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.models._utils import IntermediateLayerGetter
+
+from doctr.datasets import VOCABS, decode_sequence
+
+from ...classification import vip_tiny
+from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ..core import RecognitionModel, RecognitionPostProcessor
+
+__all__ = ["VIPTR", "viptr_tiny"]
+
+
+default_cfgs: dict[str, dict[str, Any]] = {
+    "viptr_tiny": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["french"],
+        "url": "https://doctr-static.mindee.com/models?id=v0.11.0/viptr_tiny-1cb2515e.pt&src=0",
+    },
+}
+
+
+class VIPTRPostProcessor(RecognitionPostProcessor):
+    """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
+
+    Args:
+        vocab: string containing the ordered sequence of supported characters
+    """
+
+    @staticmethod
+    def ctc_best_path(
+        logits: torch.Tensor,
+        vocab: str = VOCABS["french"],
+        blank: int = 0,
+    ) -> list[tuple[str, float]]:
+        """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
+        <https://github.com/githubharald/CTCDecoder>`_.
+
+        Args:
+            logits: model output, shape: N x T x C
+            vocab: vocabulary to use
+            blank: index of blank label
+
+        Returns:
+            A list of tuples: (word, confidence)
+        """
+        # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
+        probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
+
+        # collapse best path (using itertools.groupby), map to chars, join char list to string
+        words = [
+            decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
+            for seq in torch.argmax(logits, dim=-1)
+        ]
+
+        return list(zip(words, probs.tolist()))
+
+    def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
+        """Performs decoding of raw output with CTC and decoding of CTC predictions
+        with label_to_idx mapping dictionnary
+
+        Args:
+            logits: raw output of the model, shape (N, C + 1, seq_len)
+
+        Returns:
+            A tuple of 2 lists: a list of str (words) and a list of float (probs)
+
+        """
+        # Decode CTC
+        return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
+
+
+class VIPTR(RecognitionModel, nn.Module):
+    """Implements a VIPTR architecture as described in `"A Vision Permutable Extractor for Fast and Efficient
+    Scene Text Recognition" <https://arxiv.org/abs/2401.10110>`_.
+
+    Args:
+        feature_extractor: the backbone serving as feature extractor
+        vocab: vocabulary used for encoding
+        input_shape: input shape of the image
+        exportable: onnx exportable returns only logits
+        cfg: configuration dictionary
+    """
+
+    def __init__(
+        self,
+        feature_extractor: nn.Module,
+        vocab: str,
+        input_shape: tuple[int, int, int] = (3, 32, 128),
+        exportable: bool = False,
+        cfg: dict[str, Any] | None = None,
+    ):
+        super().__init__()
+        self.vocab = vocab
+        self.exportable = exportable
+        self.cfg = cfg
+        self.max_length = 32
+        self.vocab_size = len(vocab)
+
+        self.feat_extractor = feature_extractor
+        with torch.inference_mode():
+            embedding_units = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape[-1]
+
+        self.postprocessor = VIPTRPostProcessor(vocab=self.vocab)
+        self.head = nn.Linear(embedding_units, len(self.vocab) + 1)  # +1 for PAD
+
+        for n, m in self.named_modules():
+            # Don't override the initialization of the backbone
+            if n.startswith("feat_extractor."):
+                continue
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        target: list[str] | None = None,
+        return_model_output: bool = False,
+        return_preds: bool = False,
+    ) -> dict[str, Any]:
+        if target is not None:
+            _gt, _seq_len = self.build_target(target)
+            gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
+            gt, seq_len = gt.to(x.device), seq_len.to(x.device)
+
+        if self.training and target is None:
+            raise ValueError("Need to provide labels during training")
+
+        features = self.feat_extractor(x)["features"]  # (B, max_len, embed_dim)
+        B, N, E = features.size()
+        logits = self.head(features).view(B, N, len(self.vocab) + 1)
+
+        decoded_features = _bf16_to_float32(logits)
+
+        out: dict[str, Any] = {}
+        if self.exportable:
+            out["logits"] = decoded_features
+            return out
+
+        if return_model_output:
+            out["out_map"] = decoded_features
+
+        if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
+            # Post-process boxes
+            out["preds"] = _postprocess(decoded_features)
+
+        if target is not None:
+            out["loss"] = self.compute_loss(decoded_features, gt, seq_len, len(self.vocab))
+
+        return out
+
+    @staticmethod
+    def compute_loss(
+        model_output: torch.Tensor,
+        gt: torch.Tensor,
+        seq_len: torch.Tensor,
+        blank_idx: int = 0,
+    ) -> torch.Tensor:
+        """Compute CTC loss for the model.
+
+        Args:
+            model_output: predicted logits of the model
+            gt: ground truth tensor
+            seq_len: sequence lengths of the ground truth
+            blank_idx: index of the blank label
+
+        Returns:
+            The loss of the model on the batch
+        """
+        batch_len = model_output.shape[0]
+        input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
+        # N x T x C -> T x N x C
+        logits = model_output.permute(1, 0, 2)
+        probs = F.log_softmax(logits, dim=-1)
+        ctc_loss = F.ctc_loss(
+            probs,
+            gt,
+            input_length,
+            seq_len,
+            blank_idx,
+            zero_infinity=True,
+        )
+
+        return ctc_loss
+
+
+def _viptr(
+    arch: str,
+    pretrained: bool,
+    backbone_fn: Callable[[bool], nn.Module],
+    layer: str,
+    pretrained_backbone: bool = True,
+    ignore_keys: list[str] | None = None,
+    **kwargs: Any,
+) -> VIPTR:
+    pretrained_backbone = pretrained_backbone and not pretrained
+
+    # Patch the config
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+
+    # Feature extractor
+    feat_extractor = IntermediateLayerGetter(
+        backbone_fn(pretrained_backbone, input_shape=_cfg["input_shape"]),  # type: ignore[call-arg]
+        {layer: "features"},
+    )
+
+    kwargs["vocab"] = _cfg["vocab"]
+    kwargs["input_shape"] = _cfg["input_shape"]
+
+    model = VIPTR(feat_extractor, cfg=_cfg, **kwargs)
+
+    # Load pretrained parameters
+    if pretrained:
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+    return model
+
+
+def viptr_tiny(pretrained: bool = False, **kwargs: Any) -> VIPTR:
+    """VIPTR-Tiny as described in `"A Vision Permutable Extractor for Fast and Efficient Scene Text Recognition"
+    <https://arxiv.org/abs/2401.10110>`_.
+
+    >>> import torch
+    >>> from doctr.models import viptr_tiny
+    >>> model = viptr_tiny(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 32, 128))
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: keyword arguments of the VIPTR architecture
+
+    Returns:
+        VIPTR: a VIPTR model instance
+    """
+    return _viptr(
+        "viptr_tiny",
+        pretrained,
+        vip_tiny,
+        "5",
+        ignore_keys=["head.weight", "head.bias"],
+        **kwargs,
+    )
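Note: a minimal inference sketch for the new model (random input for shape-checking only; eval mode avoids the training-time label check in forward):

    import torch
    from doctr.models import viptr_tiny

    model = viptr_tiny(pretrained=False).eval()
    out = model(torch.rand((1, 3, 32, 128)), return_preds=True)
    print(out["preds"])  # list of (word, confidence) tuples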
doctr/models/recognition/vitstr/pytorch.py

@@ -74,6 +74,15 @@ class ViTSTR(_ViTSTR, nn.Module):
 
         self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -214,7 +223,7 @@ def _vitstr(
        # The number of classes is not the same as the number of classes in the pretrained model =>
        # remove the last layer weights
        _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
doctr/models/recognition/vitstr/tensorflow.py

@@ -74,6 +74,15 @@ class ViTSTR(_ViTSTR, Model):
 
         self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     @staticmethod
     def compute_loss(
         model_output: tf.Tensor,
@@ -217,9 +226,7 @@ def _vitstr(
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
doctr/models/recognition/zoo.py

@@ -25,6 +25,9 @@ ARCHS: list[str] = [
     "parseq",
 ]
 
+if is_torch_available():
+    ARCHS.extend(["viptr_tiny"])
+
 
 def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
     if isinstance(arch, str):
@@ -37,6 +40,8 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
     else:
         allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
         if is_torch_available():
+            # Add VIPTR which is only available in torch at the moment
+            allowed_archs.append(recognition.VIPTR)
             # Adding the type for torch compiled models to the allowed architectures
             from doctr.models.utils import _CompiledModule
 
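Note: with these zoo entries, the new architecture is reachable through the usual factory when the PyTorch backend is active (sketch):

    from doctr.models import recognition_predictor

    # "viptr_tiny" is registered in ARCHS for torch backends only
    predictor = recognition_predictor("viptr_tiny", pretrained=True)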
doctr/models/utils/pytorch.py

@@ -7,6 +7,7 @@ import logging
 from typing import Any
 
 import torch
+import validators
 from torch import nn
 
 from doctr.utils.data import download_from_url
@@ -36,7 +37,7 @@ def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
 
 def load_pretrained_params(
     model: nn.Module,
-    url: str | None = None,
+    path_or_url: str | None = None,
     hash_prefix: str | None = None,
     ignore_keys: list[str] | None = None,
     **kwargs: Any,
@@ -44,33 +45,42 @@ def load_pretrained_params(
     """Load a set of parameters onto a model
 
     >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
+    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")
 
     Args:
         model: the PyTorch model to be loaded
-        url: URL of the zipped set of parameters
+        path_or_url: the path or URL to the model parameters (checkpoint)
         hash_prefix: first characters of SHA256 expected hash
         ignore_keys: list of weights to be ignored from the state_dict
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
-    if url is None:
-        logging.warning("Invalid model URL, using default initialization.")
-    else:
-        archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
+    if path_or_url is None:
+        logging.warning("No model URL or Path provided, using default initialization.")
+        return
+
+    archive_path = (
+        download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
+        if validators.url(path_or_url)
+        else path_or_url
+    )
 
-        # Read state_dict
-        state_dict = torch.load(archive_path, map_location="cpu")
+    # Read state_dict
+    state_dict = torch.load(archive_path, map_location="cpu")
 
-        # Remove weights from the state_dict
-        if ignore_keys is not None and len(ignore_keys) > 0:
-            for key in ignore_keys:
+    # Remove weights from the state_dict
+    if ignore_keys is not None and len(ignore_keys) > 0:
+        for key in ignore_keys:
+            if key in state_dict:
                 state_dict.pop(key)
-            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-            if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0:
-                raise ValueError("unable to load state_dict, due to non-matching keys.")
-        else:
-            # Load weights
-            model.load_state_dict(state_dict)
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+        if any(k not in ignore_keys for k in missing_keys + unexpected_keys):
+            raise ValueError(
+                "Unable to load state_dict, due to non-matching keys.\n"
+                + f"Unexpected keys: {unexpected_keys}\nMissing keys: {missing_keys}"
+            )
+    else:
+        # Load weights
+        model.load_state_dict(state_dict)
 
 
 def conv_sequence_pt(
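Note: `load_pretrained_params` now accepts a local filesystem path as well as a URL, dispatching on `validators.url`, and the key-mismatch error reports the offending keys. A short sketch of both call styles (checkpoint names are placeholders):

    from doctr.models import crnn_vgg16_bn
    from doctr.models.utils import load_pretrained_params

    model = crnn_vgg16_bn(pretrained=False)

    # A URL is downloaded and cached as before:
    # load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")

    # A plain path is now loaded directly
    load_pretrained_params(model, "./yourcheckpoint.pt")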
doctr/models/utils/tensorflow.py

@@ -9,6 +9,7 @@ from typing import Any
 
 import tensorflow as tf
 import tf2onnx
+import validators
 from tensorflow.keras import Model, layers
 from doctr.utils.data import download_from_url
@@ -47,7 +48,7 @@ def _build_model(model: Model):
 
 def load_pretrained_params(
     model: Model,
-    url: str | None = None,
+    path_or_url: str | None = None,
     hash_prefix: str | None = None,
     skip_mismatch: bool = False,
     **kwargs: Any,
@@ -59,17 +60,23 @@ def load_pretrained_params(
 
     Args:
         model: the keras model to be loaded
-        url: URL of the zipped set of parameters
+        path_or_url: the path or URL to the model parameters (checkpoint)
         hash_prefix: first characters of SHA256 expected hash
         skip_mismatch: skip loading layers with mismatched shapes
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
-    if url is None:
-        logging.warning("Invalid model URL, using default initialization.")
-    else:
-        archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
-        # Load weights
-        model.load_weights(archive_path, skip_mismatch=skip_mismatch)
+    if path_or_url is None:
+        logging.warning("No model URL or Path provided, using default initialization.")
+        return
+
+    archive_path = (
+        download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
+        if validators.url(path_or_url)
+        else path_or_url
+    )
+
+    # Load weights
+    model.load_weights(archive_path, skip_mismatch=skip_mismatch)
 
 
 def conv_sequence(