python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/sar/tensorflow.py
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, Sequential, layers
@@ -18,7 +18,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "sar_resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -33,7 +33,6 @@ class SAREncoder(layers.Layer, NestedObject):
     """Implements encoder module of the SAR model
 
     Args:
-    ----
         rnn_units: number of hidden rnn units
         dropout_prob: dropout probability
     """
@@ -58,7 +57,6 @@ class AttentionModule(layers.Layer, NestedObject):
     """Implements attention module of the SAR model
 
     Args:
-    ----
         attention_units: number of hidden attention units
 
     """
@@ -120,7 +118,6 @@ class SARDecoder(layers.Layer, NestedObject):
     """Implements decoder module of the SAR model
 
    Args:
-    ----
        rnn_units: number of hidden units in recurrent cells
        max_length: maximum length of a sequence
        vocab_size: number of classes in the model alphabet
@@ -159,13 +156,13 @@ class SARDecoder(layers.Layer, NestedObject):
         self,
         features: tf.Tensor,
         holistic: tf.Tensor,
-        gt: Optional[tf.Tensor] = None,
+        gt: tf.Tensor | None = None,
         **kwargs: Any,
     ) -> tf.Tensor:
         if gt is not None:
             gt_embedding = self.embed_tgt(gt, **kwargs)
 
-        logits_list: List[tf.Tensor] = []
+        logits_list: list[tf.Tensor] = []
 
         for t in range(self.max_length + 1):  # 32
             if t == 0:
@@ -210,7 +207,6 @@ class SAR(Model, RecognitionModel):
     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -223,7 +219,7 @@ class SAR(Model, RecognitionModel):
         cfg: dictionary containing information about the model
     """
 
-    _children_names: List[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
+    _children_names: list[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
 
     def __init__(
         self,
@@ -236,7 +232,7 @@ class SAR(Model, RecognitionModel):
         num_decoder_cells: int = 2,
         dropout_prob: float = 0.0,
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -259,6 +255,15 @@ class SAR(Model, RecognitionModel):
 
         self.postprocessor = SARPostProcessor(vocab=vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     @staticmethod
     def compute_loss(
         model_output: tf.Tensor,
@@ -269,13 +274,11 @@ class SAR(Model, RecognitionModel):
         Sequences are masked after the EOS character.
 
         Args:
-        ----
            gt: the encoded tensor with gt labels
            model_output: predicted logits of the model
            seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
            The loss of the model on the batch
        """
        # Input length : number of timesteps
@@ -296,11 +299,11 @@ class SAR(Model, RecognitionModel):
     def call(
         self,
         x: tf.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)
         # vertical max pooling --> (N, C, W)
         pooled_features = tf.reduce_max(features, axis=1)
@@ -318,7 +321,7 @@ class SAR(Model, RecognitionModel):
             self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
         )
 
-        out: Dict[str, tf.Tensor] = {}
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -340,14 +343,13 @@ class SARPostProcessor(RecognitionPostProcessor):
     """Post processor for SAR architectures
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         # N x L
@@ -371,7 +373,7 @@ def _sar(
     pretrained: bool,
     backbone_fn,
     pretrained_backbone: bool = True,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> SAR:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -396,9 +398,7 @@ def _sar(
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
@@ -414,12 +414,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
    >>> out = model(input_tensor)
 
    Args:
-    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the SAR architecture
 
    Returns:
-    -------
        text recognition architecture
    """
    return _sar("sar_resnet31", pretrained, resnet31, **kwargs)
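The hunks above also add a public `from_pretrained` method to the TensorFlow `SAR` model, which `_sar` now calls instead of invoking `load_pretrained_params` directly. A minimal sketch of what this enables, assuming a local checkpoint (the path below is hypothetical; `skip_mismatch` is simply forwarded to `load_pretrained_params`, as in the diff):

    >>> from doctr.models import sar_resnet31
    >>> model = sar_resnet31(pretrained=False)
    >>> # hypothetical local path; skip_mismatch skips layers whose shapes differ (e.g. after a vocab change)
    >>> model.from_pretrained("path/to/checkpoint", skip_mismatch=True)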
doctr/models/recognition/utils.py
@@ -1,89 +1,93 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List
 
-from rapidfuzz.distance import Levenshtein
+from rapidfuzz.distance import Hamming
 
 __all__ = ["merge_strings", "merge_multi_strings"]
 
 
-def merge_strings(a: str, b: str, dil_factor: float) -> str:
+def merge_strings(a: str, b: str, overlap_ratio: float) -> str:
     """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
 
     Args:
-    ----
         a: first char seq, suffix should be similar to b's prefix.
         b: second char seq, prefix should be similar to a's suffix.
-        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
-            only used when the mother sequence is splitted on a character repetition
+        overlap_ratio: estimated ratio of overlapping characters.
 
     Returns:
-    -------
         A merged character sequence.
 
     Example::
-        >>> from doctr.model.recognition.utils import merge_sequences
-        >>> merge_sequences('abcd', 'cdefgh', 1.4)
+        >>> from doctr.models.recognition.utils import merge_strings
+        >>> merge_strings('abcd', 'cdefgh', 0.5)
         'abcdefgh'
-        >>> merge_sequences('abcdi', 'cdefgh', 1.4)
+        >>> merge_strings('abcdi', 'cdefgh', 0.5)
         'abcdefgh'
    """
    seq_len = min(len(a), len(b))
-    if seq_len == 0:  # One sequence is empty, return the other
-        return b if len(a) == 0 else a
-
-    # Initialize merging index and corresponding score (mean Levenstein)
-    min_score, index = 1.0, 0  # No overlap, just concatenate
-
-    scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
-
-    # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
-    if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
-        # Compute n_overlap (number of overlapping chars, geometrically determined)
-        n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
-        # Find the number of consecutive zeros in the scores list
-        # Impossible to have a zero after a non-zero score in that case
-        n_zeros = sum(val == 0 for val in scores)
-        # Index is bounded by the geometrical overlap to avoid collapsing repetitions
-        min_score, index = 0, min(n_zeros, n_overlap)
-
-    else:  # Common case: choose the min score index
-        for i, score in enumerate(scores):
-            if score < min_score:
-                min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char
-
-    # Merge with correct overlap
-    if index == 0:
+    if seq_len <= 1:  # One sequence is empty or will be after cropping in next step, return both to keep data
         return a + b
-    return a[:-1] + b[index - 1 :]
 
+    a_crop, b_crop = a[:-1], b[1:]  # Remove last letter of "a" and first of "b", because they might be cut off
+    max_overlap = min(len(a_crop), len(b_crop))
 
-def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
-    """Recursively merges consecutive string sequences with overlapping characters.
+    # Compute Hamming distances for all possible overlaps
+    scores = [Hamming.distance(a_crop[-i:], b_crop[:i], processor=None) for i in range(1, max_overlap + 1)]
+
+    # Find zero-score matches
+    zero_matches = [i for i, score in enumerate(scores) if score == 0]
+
+    expected_overlap = round(len(b) * overlap_ratio) - 3  # adjust for cropping and index
+
+    # Case 1: One perfect match - exactly one zero score - just merge there
+    if len(zero_matches) == 1:
+        i = zero_matches[0]
+        return a_crop + b_crop[i + 1 :]
+
+    # Case 2: Multiple perfect matches - likely due to repeated characters.
+    # Use the estimated overlap length to choose the match closest to the expected alignment.
+    elif len(zero_matches) > 1:
+        best_i = min(zero_matches, key=lambda x: abs(x - expected_overlap))
+        return a_crop + b_crop[best_i + 1 :]
+
+    # Case 3: Absence of zero scores indicates that the same character in the image was recognized differently OR that
+    # the overlap was too small and we just need to merge the crops fully
+    if expected_overlap < -1:
+        return a + b
+    elif expected_overlap < 0:
+        return a_crop + b_crop
+
+    # Find best overlap by minimizing Hamming distance + distance from expected overlap size
+    combined_scores = [score + abs(i - expected_overlap) for i, score in enumerate(scores)]
+    best_i = combined_scores.index(min(combined_scores))
+    return a_crop + b_crop[best_i + 1 :]
+
+
+def merge_multi_strings(seq_list: list[str], overlap_ratio: float, last_overlap_ratio: float) -> str:
+    """
+    Merges consecutive string sequences with overlapping characters.
 
     Args:
-    ----
         seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
-        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
-            only used when the mother sequence is splitted on a character repetition
+        overlap_ratio: Estimated ratio of overlapping letters between neighboring strings.
+        last_overlap_ratio: Estimated ratio of overlapping letters for the last element in seq_list.
 
     Returns:
-    -------
         A merged character sequence
 
     Example::
-        >>> from doctr.model.recognition.utils import merge_multi_sequences
-        >>> merge_multi_sequences(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
+        >>> from doctr.models.recognition.utils import merge_multi_strings
+        >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 0.5, 0.1)
         'abcdefghijkl'
    """
-
-    def _recursive_merge(a: str, seq_list: List[str], dil_factor: float) -> str:
-        # Recursive version of compute_overlap
-        if len(seq_list) == 1:
-            return merge_strings(a, seq_list[0], dil_factor)
-        return _recursive_merge(merge_strings(a, seq_list[0], dil_factor), seq_list[1:], dil_factor)
-
-    return _recursive_merge("", seq_list, dil_factor)
+    if not seq_list:
+        return ""
+    result = seq_list[0]
+    for i in range(1, len(seq_list)):
+        text_b = seq_list[i]
+        ratio = last_overlap_ratio if i == len(seq_list) - 1 else overlap_ratio
+        result = merge_strings(result, text_b, ratio)
+    return result
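In short, the crop-stitching helpers no longer take a box dilation factor: `merge_strings` now receives an estimated character overlap ratio and aligns the two crops via Hamming distances, and `merge_multi_strings` accepts a separate ratio for the final element of the sequence list. The doctest examples from the updated docstrings show the new signatures:

    >>> from doctr.models.recognition.utils import merge_strings, merge_multi_strings
    >>> merge_strings('abcd', 'cdefgh', 0.5)
    'abcdefgh'
    >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 0.5, 0.1)
    'abcdefghijkl'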
doctr/models/recognition/viptr/__init__.py
@@ -0,0 +1,4 @@
+from doctr.file_utils import is_torch_available
+
+if is_torch_available():
+    from .pytorch import *
doctr/models/recognition/viptr/pytorch.py
@@ -0,0 +1,277 @@
+# Copyright (C) 2021-2025, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from collections.abc import Callable
+from copy import deepcopy
+from itertools import groupby
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.models._utils import IntermediateLayerGetter
+
+from doctr.datasets import VOCABS, decode_sequence
+
+from ...classification import vip_tiny
+from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ..core import RecognitionModel, RecognitionPostProcessor
+
+__all__ = ["VIPTR", "viptr_tiny"]
+
+
+default_cfgs: dict[str, dict[str, Any]] = {
+    "viptr_tiny": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 128),
+        "vocab": VOCABS["french"],
+        "url": "https://doctr-static.mindee.com/models?id=v0.11.0/viptr_tiny-1cb2515e.pt&src=0",
+    },
+}
+
+
+class VIPTRPostProcessor(RecognitionPostProcessor):
+    """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
+
+    Args:
+        vocab: string containing the ordered sequence of supported characters
+    """
+
+    @staticmethod
+    def ctc_best_path(
+        logits: torch.Tensor,
+        vocab: str = VOCABS["french"],
+        blank: int = 0,
+    ) -> list[tuple[str, float]]:
+        """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
+        <https://github.com/githubharald/CTCDecoder>`_.
+
+        Args:
+            logits: model output, shape: N x T x C
+            vocab: vocabulary to use
+            blank: index of blank label
+
+        Returns:
+            A list of tuples: (word, confidence)
+        """
+        # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
+        probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
+
+        # collapse best path (using itertools.groupby), map to chars, join char list to string
+        words = [
+            decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
+            for seq in torch.argmax(logits, dim=-1)
+        ]
+
+        return list(zip(words, probs.tolist()))
+
+    def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
+        """Performs decoding of raw output with CTC and decoding of CTC predictions
+        with label_to_idx mapping dictionnary
+
+        Args:
+            logits: raw output of the model, shape (N, C + 1, seq_len)
+
+        Returns:
+            A tuple of 2 lists: a list of str (words) and a list of float (probs)
+
+        """
+        # Decode CTC
+        return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
+
+
+class VIPTR(RecognitionModel, nn.Module):
+    """Implements a VIPTR architecture as described in `"A Vision Permutable Extractor for Fast and Efficient
+    Scene Text Recognition" <https://arxiv.org/abs/2401.10110>`_.
+
+    Args:
+        feature_extractor: the backbone serving as feature extractor
+        vocab: vocabulary used for encoding
+        input_shape: input shape of the image
+        exportable: onnx exportable returns only logits
+        cfg: configuration dictionary
+    """
+
+    def __init__(
+        self,
+        feature_extractor: nn.Module,
+        vocab: str,
+        input_shape: tuple[int, int, int] = (3, 32, 128),
+        exportable: bool = False,
+        cfg: dict[str, Any] | None = None,
+    ):
+        super().__init__()
+        self.vocab = vocab
+        self.exportable = exportable
+        self.cfg = cfg
+        self.max_length = 32
+        self.vocab_size = len(vocab)
+
+        self.feat_extractor = feature_extractor
+        with torch.inference_mode():
+            embedding_units = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape[-1]
+
+        self.postprocessor = VIPTRPostProcessor(vocab=self.vocab)
+        self.head = nn.Linear(embedding_units, len(self.vocab) + 1)  # +1 for PAD
+
+        for n, m in self.named_modules():
+            # Don't override the initialization of the backbone
+            if n.startswith("feat_extractor."):
+                continue
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        target: list[str] | None = None,
+        return_model_output: bool = False,
+        return_preds: bool = False,
+    ) -> dict[str, Any]:
+        if target is not None:
+            _gt, _seq_len = self.build_target(target)
+            gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
+            gt, seq_len = gt.to(x.device), seq_len.to(x.device)
+
+        if self.training and target is None:
+            raise ValueError("Need to provide labels during training")
+
+        features = self.feat_extractor(x)["features"]  # (B, max_len, embed_dim)
+        B, N, E = features.size()
+        logits = self.head(features).view(B, N, len(self.vocab) + 1)
+
+        decoded_features = _bf16_to_float32(logits)
+
+        out: dict[str, Any] = {}
+        if self.exportable:
+            out["logits"] = decoded_features
+            return out
+
+        if return_model_output:
+            out["out_map"] = decoded_features
+
+        if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
+            # Post-process boxes
+            out["preds"] = _postprocess(decoded_features)
+
+        if target is not None:
+            out["loss"] = self.compute_loss(decoded_features, gt, seq_len, len(self.vocab))
+
+        return out
+
+    @staticmethod
+    def compute_loss(
+        model_output: torch.Tensor,
+        gt: torch.Tensor,
+        seq_len: torch.Tensor,
+        blank_idx: int = 0,
+    ) -> torch.Tensor:
+        """Compute CTC loss for the model.
+
+        Args:
+            model_output: predicted logits of the model
+            gt: ground truth tensor
+            seq_len: sequence lengths of the ground truth
+            blank_idx: index of the blank label
+
+        Returns:
+            The loss of the model on the batch
+        """
+        batch_len = model_output.shape[0]
+        input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
+        # N x T x C -> T x N x C
+        logits = model_output.permute(1, 0, 2)
+        probs = F.log_softmax(logits, dim=-1)
+        ctc_loss = F.ctc_loss(
+            probs,
+            gt,
+            input_length,
+            seq_len,
+            blank_idx,
+            zero_infinity=True,
+        )
+
+        return ctc_loss
+
+
+def _viptr(
+    arch: str,
+    pretrained: bool,
+    backbone_fn: Callable[[bool], nn.Module],
+    layer: str,
+    pretrained_backbone: bool = True,
+    ignore_keys: list[str] | None = None,
+    **kwargs: Any,
+) -> VIPTR:
+    pretrained_backbone = pretrained_backbone and not pretrained
+
+    # Patch the config
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+
+    # Feature extractor
+    feat_extractor = IntermediateLayerGetter(
+        backbone_fn(pretrained_backbone, input_shape=_cfg["input_shape"]),  # type: ignore[call-arg]
+        {layer: "features"},
+    )
+
+    kwargs["vocab"] = _cfg["vocab"]
+    kwargs["input_shape"] = _cfg["input_shape"]
+
+    model = VIPTR(feat_extractor, cfg=_cfg, **kwargs)
+
+    # Load pretrained parameters
+    if pretrained:
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+    return model
+
+
+def viptr_tiny(pretrained: bool = False, **kwargs: Any) -> VIPTR:
+    """VIPTR-Tiny as described in `"A Vision Permutable Extractor for Fast and Efficient Scene Text Recognition"
+    <https://arxiv.org/abs/2401.10110>`_.
+
+    >>> import torch
+    >>> from doctr.models import viptr_tiny
+    >>> model = viptr_tiny(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 32, 128))
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: keyword arguments of the VIPTR architecture
+
+    Returns:
+        VIPTR: a VIPTR model instance
+    """
+    return _viptr(
+        "viptr_tiny",
+        pretrained,
+        vip_tiny,
+        "5",
+        ignore_keys=["head.weight", "head.bias"],
+        **kwargs,
+    )
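Beyond the inference doctest in `viptr_tiny`, the `forward` signature above also accepts ground-truth targets, in which case the output dict carries a CTC loss alongside the decoded predictions. A minimal sketch of that training-style call (the sample words are arbitrary; the keys match the `forward` method shown above):

    >>> import torch
    >>> from doctr.models import viptr_tiny
    >>> model = viptr_tiny(pretrained=False)
    >>> x = torch.rand((2, 3, 32, 128))
    >>> out = model(x, target=["hello", "world"], return_preds=True)
    >>> out["loss"]    # CTC loss tensor
    >>> out["preds"]   # list of (word, confidence) tuples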
doctr/models/recognition/vitstr/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/recognition/vitstr/base.py
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Tuple
 
 import numpy as np
 
@@ -17,17 +16,15 @@ class _ViTSTR:
 
     def build_target(
         self,
-        gts: List[str],
-    ) -> Tuple[np.ndarray, List[int]]:
+        gts: list[str],
+    ) -> tuple[np.ndarray, list[int]]:
         """Encode a list of gts sequences into a np array and gives the corresponding*
         sequence lengths.
 
         Args:
-        ----
            gts: list of ground-truth labels
 
         Returns:
-        -------
            A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
        """
        encoded = encode_sequences(
@@ -45,7 +42,6 @@ class _ViTSTRPostProcessor(RecognitionPostProcessor):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 