python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0

doctr/models/recognition/predictor/pytorch.py
@@ -38,13 +38,13 @@ class RecognitionPredictor(nn.Module):
         self.model = model.eval()
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.dil_factor = 1.4  # Dilation factor to overlap the crops
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio

     @torch.inference_mode()
     def forward(
         self,
-        crops: Sequence[np.ndarray | torch.Tensor],
+        crops: Sequence[np.ndarray],
         **kwargs: Any,
     ) -> list[tuple[str, float]]:
         if len(crops) == 0:
@@ -60,8 +60,7 @@ class RecognitionPredictor(nn.Module):
                 crops,  # type: ignore[arg-type]
                 self.critical_ar,
                 self.target_ar,
-                self.dil_factor,
-                isinstance(crops[0], np.ndarray),
+                self.overlap_ratio,
             )
             if remapped:
                 crops = new_crops
@@ -81,6 +80,6 @@ class RecognitionPredictor(nn.Module):

         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.dil_factor)
+            out = remap_preds(out, crop_map, self.overlap_ratio)

         return out
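
Note on the change above: `dil_factor` (a box dilation factor, required to be > 1) is replaced by `overlap_ratio`, the fraction of each window that is shared with its neighbor. The following is a minimal standalone sketch of what a 0.5 overlap means for a wide crop; it is illustrative only, not doctr's actual splitting helper, and `toy_split` with its defaults is invented for this note:

    import numpy as np

    def toy_split(crop: np.ndarray, target_ar: int = 6, overlap_ratio: float = 0.5) -> list[np.ndarray]:
        h, w = crop.shape[:2]
        win = target_ar * h  # window width at the target aspect ratio
        if w <= win:
            return [crop]
        stride = max(1, int(win * (1 - overlap_ratio)))  # advance by half a window at 0.5
        starts = list(range(0, w - win, stride))
        starts.append(w - win)  # right-align the final window so no pixels are dropped
        return [crop[:, s : s + win] for s in starts]

    pieces = toy_split(np.zeros((32, 512, 3), dtype=np.uint8))
    print([p.shape[1] for p in pieces])  # five 192-px windows

Because the final window is right-aligned, it generally overlaps its neighbor by more than `overlap_ratio`; that is why `merge_multi_strings` (see the utils.py hunk further down) takes a separate `last_overlap_ratio`.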

doctr/models/recognition/sar/__init__.py
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *

doctr/models/recognition/sar/pytorch.py
@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS

 from ...classification import resnet31
-from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor

 __all__ = ["SAR", "sar_resnet31"]
@@ -228,6 +228,15 @@ class SAR(nn.Module, RecognitionModel):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)

+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -263,7 +272,7 @@ class SAR(nn.Module, RecognitionModel):

         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(decoded_features)

@@ -295,7 +304,7 @@ class SAR(nn.Module, RecognitionModel):
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token
-        seq_len = seq_len + 1  # type: ignore[assignment]
+        seq_len = seq_len + 1
         # Compute loss
         # (N, L, vocab_size + 1)
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
@@ -364,7 +373,7 @@ def _sar(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     return model
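
The `from_pretrained` method added here (and mirrored on the other recognition models in this release) exposes checkpoint loading as a public API rather than an internal call to `load_pretrained_params`. A short usage sketch; the checkpoint path is a placeholder:

    from doctr.models import sar_resnet31

    model = sar_resnet31(pretrained=False)
    # Accepts a local path or a URL; extra kwargs are forwarded to load_pretrained_params
    model.from_pretrained("path/to/checkpoint.pt")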

doctr/models/recognition/utils.py
@@ -4,81 +4,90 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


-from rapidfuzz.distance import Levenshtein
+from rapidfuzz.distance import Hamming

 __all__ = ["merge_strings", "merge_multi_strings"]


-def merge_strings(a: str, b: str, dil_factor: float) -> str:
+def merge_strings(a: str, b: str, overlap_ratio: float) -> str:
     """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.

     Args:
         a: first char seq, suffix should be similar to b's prefix.
         b: second char seq, prefix should be similar to a's suffix.
-        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
-            only used when the mother sequence is splitted on a character repetition
+        overlap_ratio: estimated ratio of overlapping characters.

     Returns:
         A merged character sequence.

     Example::
-        >>> from doctr.models.recognition.utils import merge_sequences
-        >>> merge_sequences('abcd', 'cdefgh', 1.4)
+        >>> from doctr.models.recognition.utils import merge_strings
+        >>> merge_strings('abcd', 'cdefgh', 0.5)
         'abcdefgh'
-        >>> merge_sequences('abcdi', 'cdefgh', 1.4)
+        >>> merge_strings('abcdi', 'cdefgh', 0.5)
         'abcdefgh'
     """
     seq_len = min(len(a), len(b))
-    if seq_len == 0:  # One sequence is empty, return the other
-        return b if len(a) == 0 else a
-
-    # Initialize merging index and corresponding score (mean Levenstein)
-    min_score, index = 1.0, 0  # No overlap, just concatenate
-
-    scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
-
-    # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
-    if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
-        # Compute n_overlap (number of overlapping chars, geometrically determined)
-        n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
-        # Find the number of consecutive zeros in the scores list
-        # Impossible to have a zero after a non-zero score in that case
-        n_zeros = sum(val == 0 for val in scores)
-        # Index is bounded by the geometrical overlap to avoid collapsing repetitions
-        min_score, index = 0, min(n_zeros, n_overlap)
-
-    else:  # Common case: choose the min score index
-        for i, score in enumerate(scores):
-            if score < min_score:
-                min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char
-
-    # Merge with correct overlap
-    if index == 0:
+    if seq_len <= 1:  # One sequence is empty or will be after cropping in next step, return both to keep data
         return a + b
-    return a[:-1] + b[index - 1 :]

+    a_crop, b_crop = a[:-1], b[1:]  # Remove last letter of "a" and first of "b", because they might be cut off
+    max_overlap = min(len(a_crop), len(b_crop))

-def merge_multi_strings(seq_list: list[str], dil_factor: float) -> str:
-    """Recursively merges consecutive string sequences with overlapping characters.
+    # Compute Hamming distances for all possible overlaps
+    scores = [Hamming.distance(a_crop[-i:], b_crop[:i], processor=None) for i in range(1, max_overlap + 1)]
+
+    # Find zero-score matches
+    zero_matches = [i for i, score in enumerate(scores) if score == 0]
+
+    expected_overlap = round(len(b) * overlap_ratio) - 3  # adjust for cropping and index
+
+    # Case 1: One perfect match - exactly one zero score - just merge there
+    if len(zero_matches) == 1:
+        i = zero_matches[0]
+        return a_crop + b_crop[i + 1 :]
+
+    # Case 2: Multiple perfect matches - likely due to repeated characters.
+    # Use the estimated overlap length to choose the match closest to the expected alignment.
+    elif len(zero_matches) > 1:
+        best_i = min(zero_matches, key=lambda x: abs(x - expected_overlap))
+        return a_crop + b_crop[best_i + 1 :]
+
+    # Case 3: Absence of zero scores indicates that the same character in the image was recognized differently OR that
+    # the overlap was too small and we just need to merge the crops fully
+    if expected_overlap < -1:
+        return a + b
+    elif expected_overlap < 0:
+        return a_crop + b_crop
+
+    # Find best overlap by minimizing Hamming distance + distance from expected overlap size
+    combined_scores = [score + abs(i - expected_overlap) for i, score in enumerate(scores)]
+    best_i = combined_scores.index(min(combined_scores))
+    return a_crop + b_crop[best_i + 1 :]
+
+
+def merge_multi_strings(seq_list: list[str], overlap_ratio: float, last_overlap_ratio: float) -> str:
+    """
+    Merges consecutive string sequences with overlapping characters.

     Args:
         seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
-        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
-            only used when the mother sequence is splitted on a character repetition
+        overlap_ratio: Estimated ratio of overlapping letters between neighboring strings.
+        last_overlap_ratio: Estimated ratio of overlapping letters for the last element in seq_list.

     Returns:
         A merged character sequence

     Example::
-        >>> from doctr.models.recognition.utils import merge_multi_sequences
-        >>> merge_multi_sequences(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
+        >>> from doctr.models.recognition.utils import merge_multi_strings
+        >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 0.5, 0.1)
         'abcdefghijkl'
     """
-
-    def _recursive_merge(a: str, seq_list: list[str], dil_factor: float) -> str:
-        # Recursive version of compute_overlap
-        if len(seq_list) == 1:
-            return merge_strings(a, seq_list[0], dil_factor)
-        return _recursive_merge(merge_strings(a, seq_list[0], dil_factor), seq_list[1:], dil_factor)
-
-    return _recursive_merge("", seq_list, dil_factor)
+    if not seq_list:
+        return ""
+    result = seq_list[0]
+    for i in range(1, len(seq_list)):
+        text_b = seq_list[i]
+        ratio = last_overlap_ratio if i == len(seq_list) - 1 else overlap_ratio
+        result = merge_strings(result, text_b, ratio)
+    return result
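
A usage sketch assembled from the docstrings above (the expected outputs are the documented ones):

    from doctr.models.recognition.utils import merge_strings, merge_multi_strings

    # Two neighboring windows whose overlap is estimated at half of the second string
    merge_strings("abcd", "cdefgh", 0.5)  # 'abcdefgh'

    # The last element gets its own ratio: the final window is right-aligned by the
    # splitting step, so its overlap differs from the others
    merge_multi_strings(["abc", "bcdef", "difghi", "aijkl"], 0.5, 0.1)  # 'abcdefghijkl'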

doctr/models/recognition/viptr/__init__.py
@@ -0,0 +1 @@
+from .pytorch import *

doctr/models/recognition/viptr/pytorch.py
@@ -0,0 +1,277 @@
+# Copyright (C) 2021-2025, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from collections.abc import Callable
+from copy import deepcopy
+from itertools import groupby
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.models._utils import IntermediateLayerGetter
+
+from doctr.datasets import VOCABS, decode_sequence
+
+from ...classification import vip_tiny
+from ...utils import _bf16_to_float32, load_pretrained_params
+from ..core import RecognitionModel, RecognitionPostProcessor
+
+__all__ = ["VIPTR", "viptr_tiny"]
+
+
+default_cfgs: dict[str, dict[str, Any]] = {
+    "viptr_tiny": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["french"],
+        "url": "https://doctr-static.mindee.com/models?id=v0.11.0/viptr_tiny-1cb2515e.pt&src=0",
+    },
+}
+
+
+class VIPTRPostProcessor(RecognitionPostProcessor):
+    """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
+
+    Args:
+        vocab: string containing the ordered sequence of supported characters
+    """
+
+    @staticmethod
+    def ctc_best_path(
+        logits: torch.Tensor,
+        vocab: str = VOCABS["french"],
+        blank: int = 0,
+    ) -> list[tuple[str, float]]:
+        """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
+        <https://github.com/githubharald/CTCDecoder>`_.
+
+        Args:
+            logits: model output, shape: N x T x C
+            vocab: vocabulary to use
+            blank: index of blank label
+
+        Returns:
+            A list of tuples: (word, confidence)
+        """
+        # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
+        probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
+
+        # collapse best path (using itertools.groupby), map to chars, join char list to string
+        words = [
+            decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
+            for seq in torch.argmax(logits, dim=-1)
+        ]
+
+        return list(zip(words, probs.tolist()))
+
+    def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
+        """Performs decoding of raw output with CTC and decoding of CTC predictions
+        with label_to_idx mapping dictionary
+
+        Args:
+            logits: raw output of the model, shape (N, C + 1, seq_len)
+
+        Returns:
+            A tuple of 2 lists: a list of str (words) and a list of float (probs)
+
+        """
+        # Decode CTC
+        return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
+
+
+class VIPTR(RecognitionModel, nn.Module):
+    """Implements a VIPTR architecture as described in `"A Vision Permutable Extractor for Fast and Efficient
+    Scene Text Recognition" <https://arxiv.org/abs/2401.10110>`_.
+
+    Args:
+        feature_extractor: the backbone serving as feature extractor
+        vocab: vocabulary used for encoding
+        input_shape: input shape of the image
+        exportable: onnx exportable returns only logits
+        cfg: configuration dictionary
+    """
+
+    def __init__(
+        self,
+        feature_extractor: nn.Module,
+        vocab: str,
+        input_shape: tuple[int, int, int] = (3, 32, 128),
+        exportable: bool = False,
+        cfg: dict[str, Any] | None = None,
+    ):
+        super().__init__()
+        self.vocab = vocab
+        self.exportable = exportable
+        self.cfg = cfg
+        self.max_length = 32
+        self.vocab_size = len(vocab)
+
+        self.feat_extractor = feature_extractor
+        with torch.inference_mode():
+            embedding_units = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape[-1]
+
+        self.postprocessor = VIPTRPostProcessor(vocab=self.vocab)
+        self.head = nn.Linear(embedding_units, len(self.vocab) + 1)  # +1 for PAD
+
+        for n, m in self.named_modules():
+            # Don't override the initialization of the backbone
+            if n.startswith("feat_extractor."):
+                continue
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        target: list[str] | None = None,
+        return_model_output: bool = False,
+        return_preds: bool = False,
+    ) -> dict[str, Any]:
+        if target is not None:
+            _gt, _seq_len = self.build_target(target)
+            gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
+            gt, seq_len = gt.to(x.device), seq_len.to(x.device)
+
+        if self.training and target is None:
+            raise ValueError("Need to provide labels during training")
+
+        features = self.feat_extractor(x)["features"]  # (B, max_len, embed_dim)
+        B, N, E = features.size()
+        logits = self.head(features).view(B, N, len(self.vocab) + 1)
+
+        decoded_features = _bf16_to_float32(logits)
+
+        out: dict[str, Any] = {}
+        if self.exportable:
+            out["logits"] = decoded_features
+            return out
+
+        if return_model_output:
+            out["out_map"] = decoded_features
+
+        if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
+            # Post-process boxes
+            out["preds"] = _postprocess(decoded_features)
+
+        if target is not None:
+            out["loss"] = self.compute_loss(decoded_features, gt, seq_len, len(self.vocab))
+
+        return out
+
+    @staticmethod
+    def compute_loss(
+        model_output: torch.Tensor,
+        gt: torch.Tensor,
+        seq_len: torch.Tensor,
+        blank_idx: int = 0,
+    ) -> torch.Tensor:
+        """Compute CTC loss for the model.
+
+        Args:
+            model_output: predicted logits of the model
+            gt: ground truth tensor
+            seq_len: sequence lengths of the ground truth
+            blank_idx: index of the blank label
+
+        Returns:
+            The loss of the model on the batch
+        """
+        batch_len = model_output.shape[0]
+        input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
+        # N x T x C -> T x N x C
+        logits = model_output.permute(1, 0, 2)
+        probs = F.log_softmax(logits, dim=-1)
+        ctc_loss = F.ctc_loss(
+            probs,
+            gt,
+            input_length,
+            seq_len,
+            blank_idx,
+            zero_infinity=True,
+        )
+
+        return ctc_loss
+
+
+def _viptr(
+    arch: str,
+    pretrained: bool,
+    backbone_fn: Callable[[bool], nn.Module],
+    layer: str,
+    pretrained_backbone: bool = True,
+    ignore_keys: list[str] | None = None,
+    **kwargs: Any,
+) -> VIPTR:
+    pretrained_backbone = pretrained_backbone and not pretrained
+
+    # Patch the config
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+
+    # Feature extractor
+    feat_extractor = IntermediateLayerGetter(
+        backbone_fn(pretrained_backbone, input_shape=_cfg["input_shape"]),  # type: ignore[call-arg]
+        {layer: "features"},
+    )
+
+    kwargs["vocab"] = _cfg["vocab"]
+    kwargs["input_shape"] = _cfg["input_shape"]
+
+    model = VIPTR(feat_extractor, cfg=_cfg, **kwargs)
+
+    # Load pretrained parameters
+    if pretrained:
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+    return model
+
+
+def viptr_tiny(pretrained: bool = False, **kwargs: Any) -> VIPTR:
+    """VIPTR-Tiny as described in `"A Vision Permutable Extractor for Fast and Efficient Scene Text Recognition"
+    <https://arxiv.org/abs/2401.10110>`_.
+
+    >>> import torch
+    >>> from doctr.models import viptr_tiny
+    >>> model = viptr_tiny(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 32, 128))
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: keyword arguments of the VIPTR architecture
+
+    Returns:
+        VIPTR: a VIPTR model instance
+    """
+    return _viptr(
+        "viptr_tiny",
+        pretrained,
+        vip_tiny,
+        "5",
+        ignore_keys=["head.weight", "head.bias"],
+        **kwargs,
+    )
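
Since `viptr_tiny` is also registered in the recognition zoo (see the zoo.py hunks below), the new architecture is reachable through the standard predictor API. A short sketch; the all-zero crop is just a stand-in for a real word image:

    import numpy as np
    from doctr.models import recognition_predictor

    predictor = recognition_predictor("viptr_tiny", pretrained=True)
    # The predictor consumes numpy word crops (H x W x C) and returns (word, confidence) pairs
    crops = [np.zeros((32, 128, 3), dtype=np.uint8)]
    print(predictor(crops))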

doctr/models/recognition/vitstr/__init__.py
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *

doctr/models/recognition/vitstr/pytorch.py
@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS

 from ...classification import vit_b, vit_s
-from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor

 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -74,6 +74,15 @@ class ViTSTR(_ViTSTR, nn.Module):

         self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)

+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -108,7 +117,7 @@ class ViTSTR(_ViTSTR, nn.Module):

         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(decoded_features)

@@ -140,7 +149,7 @@ class ViTSTR(_ViTSTR, nn.Module):
         # Input length : number of steps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1  # type: ignore[assignment]
+        seq_len = seq_len + 1
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>.
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -214,7 +223,7 @@ def _vitstr(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     return model

doctr/models/recognition/zoo.py
@@ -5,8 +5,8 @@

 from typing import Any

-from doctr.file_utils import is_tf_available, is_torch_available
 from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import _CompiledModule

 from .. import recognition
 from .predictor import RecognitionPredictor
@@ -23,6 +23,7 @@ ARCHS: list[str] = [
     "vitstr_small",
     "vitstr_base",
     "parseq",
+    "viptr_tiny",
 ]


@@ -35,12 +36,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
             pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
         )
     else:
-        allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
-        if is_torch_available():
-            # Adding the type for torch compiled models to the allowed architectures
-            from doctr.models.utils import _CompiledModule
-
-            allowed_archs.append(_CompiledModule)
+        # Adding the type for torch compiled models to the allowed architectures
+        allowed_archs = [
+            recognition.CRNN,
+            recognition.SAR,
+            recognition.MASTER,
+            recognition.ViTSTR,
+            recognition.PARSeq,
+            recognition.VIPTR,
+            _CompiledModule,
+        ]

         if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
@@ -51,7 +56,7 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 128)
-    input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:]
+    input_shape = _model.cfg["input_shape"][-2:]
     predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)

     return predictor
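
With `_CompiledModule` now unconditionally part of `allowed_archs`, a `torch.compile`d recognition model should be accepted directly as the `arch` argument. A sketch of that path, assuming the compiled wrapper forwards the `cfg` attribute of the wrapped model:

    import torch
    from doctr.models import recognition_predictor, viptr_tiny

    compiled = torch.compile(viptr_tiny(pretrained=True))
    predictor = recognition_predictor(compiled)  # passes the isinstance check via _CompiledModule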

doctr/models/utils/__init__.py
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *