python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +8 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +7 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +4 -5
  17. doctr/datasets/ic13.py +4 -5
  18. doctr/datasets/iiit5k.py +6 -5
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +6 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +6 -5
  27. doctr/datasets/svhn.py +6 -5
  28. doctr/datasets/svt.py +4 -5
  29. doctr/datasets/synthtext.py +4 -5
  30. doctr/datasets/utils.py +34 -29
  31. doctr/datasets/vocabs.py +17 -7
  32. doctr/datasets/wildreceipt.py +14 -10
  33. doctr/file_utils.py +2 -7
  34. doctr/io/elements.py +59 -79
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +30 -48
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +8 -11
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +5 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +8 -21
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +6 -8
  52. doctr/models/classification/predictor/tensorflow.py +6 -8
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +20 -31
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +8 -15
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +9 -12
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +6 -12
  65. doctr/models/classification/zoo.py +19 -14
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +14 -26
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +14 -23
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +5 -6
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +3 -7
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +4 -5
  91. doctr/models/kie_predictor/pytorch.py +18 -19
  92. doctr/models/kie_predictor/tensorflow.py +13 -14
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -10
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  101. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +28 -29
  104. doctr/models/predictor/pytorch.py +12 -13
  105. doctr/models/predictor/tensorflow.py +8 -9
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +10 -14
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +11 -23
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +12 -22
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +16 -22
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +12 -21
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +12 -20
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +14 -17
  136. doctr/models/utils/tensorflow.py +17 -16
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +20 -28
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +58 -22
  145. doctr/transforms/modules/tensorflow.py +18 -32
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +16 -47
  150. doctr/utils/metrics.py +17 -37
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +9 -13
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.10.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0

doctr/models/recognition/predictor/tensorflow.py

@@ -1,9 +1,9 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Tuple, Union
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -21,13 +21,12 @@ class RecognitionPredictor(NestedObject):
     """Implements an object able to identify character sequences in images
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core detection architecture
         split_wide_crops: wether to use crop splitting for high aspect ratio crops
     """
 
-    _children_names: List[str] = ["pre_processor", "model"]
+    _children_names: list[str] = ["pre_processor", "model"]
 
     def __init__(
         self,
@@ -45,9 +44,9 @@ class RecognitionPredictor(NestedObject):
 
     def __call__(
         self,
-        crops: List[Union[np.ndarray, tf.Tensor]],
+        crops: list[np.ndarray | tf.Tensor],
         **kwargs: Any,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         if len(crops) == 0:
             return []
         # Dimension check
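
Most of the 162 files change the same way: annotations move from `typing` generics (`List`, `Tuple`, `Dict`, `Optional`, `Union`) to the built-in generics of PEP 585 and the `|` unions of PEP 604, with `Callable` relocated to `collections.abc`. A minimal before/after sketch with hypothetical `predict_old`/`predict_new` functions (builtin generics need Python 3.9+ at runtime, `|` unions 3.10+):

    from typing import Any, Dict, List, Optional, Tuple, Union

    # 0.10.0 style: generics imported from typing
    def predict_old(
        crops: List[Union[bytes, str]], cfg: Optional[Dict[str, Any]] = None
    ) -> List[Tuple[str, float]]:
        return [(str(c), 1.0) for c in crops]

    # 0.11.0 style: builtin generics (PEP 585) and union syntax (PEP 604)
    def predict_new(
        crops: list[bytes | str], cfg: dict[str, Any] | None = None
    ) -> list[tuple[str, float]]:
        return [(str(c), 1.0) for c in crops]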

doctr/models/recognition/sar/__init__.py

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
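
The flipped branch order is a behavioral change: with both frameworks installed, 0.10.0 resolved these modules to the TensorFlow implementation, while 0.11.0 prefers PyTorch. A simplified sketch of the dispatch (the real `doctr.file_utils` helpers also honor environment overrides such as `USE_TORCH`/`USE_TF`, omitted here):

    # Stand-in availability probes mirroring the doctr.file_utils pattern
    def is_torch_available() -> bool:
        try:
            import torch  # noqa: F401
            return True
        except ImportError:
            return False

    def is_tf_available() -> bool:
        try:
            import tensorflow  # noqa: F401
            return True
        except ImportError:
            return False

    # With both frameworks installed, 0.11.0 now resolves to PyTorch
    backend = "pytorch" if is_torch_available() else "tensorflow" if is_tf_available() else None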

doctr/models/recognition/sar/pytorch.py

@@ -1,10 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any
 
 import torch
 from torch import nn
@@ -19,7 +20,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "sar_resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -80,7 +81,6 @@ class SARDecoder(nn.Module):
     """Implements decoder module of the SAR model
 
     Args:
-    ----
         rnn_units: number of hidden units in recurrent cells
         max_length: maximum length of a sequence
         vocab_size: number of classes in the model alphabet
@@ -114,12 +114,12 @@ class SARDecoder(nn.Module):
         self,
         features: torch.Tensor,  # (N, C, H, W)
         holistic: torch.Tensor,  # (N, C)
-        gt: Optional[torch.Tensor] = None,  # (N, L)
+        gt: torch.Tensor | None = None,  # (N, L)
     ) -> torch.Tensor:
         if gt is not None:
             gt_embedding = self.embed_tgt(gt)
 
-        logits_list: List[torch.Tensor] = []
+        logits_list: list[torch.Tensor] = []
 
         for t in range(self.max_length + 1):  # 32
             if t == 0:
@@ -166,7 +166,6 @@ class SAR(nn.Module, RecognitionModel):
     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -187,9 +186,9 @@ class SAR(nn.Module, RecognitionModel):
         attention_units: int = 512,
         max_length: int = 30,
         dropout_prob: float = 0.0,
-        input_shape: Tuple[int, int, int] = (3, 32, 128),
+        input_shape: tuple[int, int, int] = (3, 32, 128),
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -232,10 +231,10 @@ class SAR(nn.Module, RecognitionModel):
     def forward(
         self,
         x: torch.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]
         # NOTE: use max instead of functional max_pool2d which leads to ONNX incompatibility (kernel_size)
         # Vertical max pooling (N, C, H, W) --> (N, C, W)
@@ -254,7 +253,7 @@ class SAR(nn.Module, RecognitionModel):
 
         decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -263,8 +262,13 @@ class SAR(nn.Module, RecognitionModel):
         out["out_map"] = decoded_features
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
             # Post-process boxes
-            out["preds"] = self.postprocessor(decoded_features)
+            out["preds"] = _postprocess(decoded_features)
 
         if target is not None:
             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
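
The `_postprocess` wrapper is the substantive change in this hunk: decoding logits into strings is data-dependent Python that `torch.compile` cannot trace, so it is fenced off with `torch.compiler.disable` (available since PyTorch 2.1) while the rest of `forward` stays compilable. The same pattern in isolation, with a hypothetical `decode` helper:

    import torch

    @torch.compiler.disable  # excluded from torch.compile graph capture, always runs eagerly
    def decode(logits: torch.Tensor, vocab: str) -> list[str]:
        # Building Python strings is not traceable by the compiler
        return ["".join(vocab[i] for i in row) for row in logits.argmax(-1).tolist()]

    vocab = "abc"
    words = decode(torch.randn(2, 4, len(vocab)), vocab)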
@@ -281,19 +285,17 @@ class SAR(nn.Module, RecognitionModel):
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1  # type: ignore[assignment]
         # Compute loss
         # (N, L, vocab_size + 1)
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
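
Apart from the `type: ignore`, `compute_loss` is unchanged; "masked after the EOS character" means the per-timestep cross-entropy is zeroed beyond each target's `<eos>` position before averaging. A self-contained sketch of that masking, following the shape conventions above but not copied verbatim from doctr:

    import torch
    import torch.nn.functional as F

    def masked_ce_loss(model_output: torch.Tensor, gt: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
        # model_output: (N, L, vocab_size + 1), gt: (N, L), seq_len: (N,)
        input_len = model_output.shape[1]
        seq_len = seq_len + 1  # account for the <eos> token
        cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")  # (N, L)
        # Zero the timesteps past each sequence's <eos>
        mask = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None]
        cce = cce.masked_fill(mask, 0.0)
        return (cce.sum(1) / seq_len.to(model_output.dtype)).mean()

    loss = masked_ce_loss(torch.randn(2, 8, 11), torch.randint(0, 11, (2, 8)), torch.tensor([3, 5]))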
@@ -308,14 +310,13 @@ class SARPostProcessor(RecognitionPostProcessor):
     """Post processor for SAR architectures
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: torch.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
         # N x L
@@ -338,7 +339,7 @@ def _sar(
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
     pretrained_backbone: bool = True,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> SAR:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -379,12 +380,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the SAR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _sar(
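
None of the signature changes above affect the public entry point; a call mirroring the `sar_resnet31` docstring (the target word here is an arbitrary assumption of this sketch) still returns a dict holding `preds` and, when a target is passed, `loss`:

    import torch
    from doctr.models import sar_resnet31

    model = sar_resnet31(pretrained=False)
    input_tensor = torch.rand((1, 3, 32, 128))  # default input_shape from the cfg above
    out = model(input_tensor, target=["hello"], return_preds=True)
    # out["loss"] is a scalar tensor; out["preds"] is a list of (word, confidence) pairs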

doctr/models/recognition/sar/tensorflow.py

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, Sequential, layers
@@ -18,7 +18,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "sar_resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -33,7 +33,6 @@ class SAREncoder(layers.Layer, NestedObject):
     """Implements encoder module of the SAR model
 
     Args:
-    ----
         rnn_units: number of hidden rnn units
         dropout_prob: dropout probability
     """
@@ -58,7 +57,6 @@ class AttentionModule(layers.Layer, NestedObject):
     """Implements attention module of the SAR model
 
     Args:
-    ----
         attention_units: number of hidden attention units
 
     """
@@ -120,7 +118,6 @@ class SARDecoder(layers.Layer, NestedObject):
     """Implements decoder module of the SAR model
 
     Args:
-    ----
         rnn_units: number of hidden units in recurrent cells
         max_length: maximum length of a sequence
         vocab_size: number of classes in the model alphabet
@@ -159,13 +156,13 @@ class SARDecoder(layers.Layer, NestedObject):
         self,
         features: tf.Tensor,
         holistic: tf.Tensor,
-        gt: Optional[tf.Tensor] = None,
+        gt: tf.Tensor | None = None,
         **kwargs: Any,
     ) -> tf.Tensor:
         if gt is not None:
             gt_embedding = self.embed_tgt(gt, **kwargs)
 
-        logits_list: List[tf.Tensor] = []
+        logits_list: list[tf.Tensor] = []
 
         for t in range(self.max_length + 1):  # 32
             if t == 0:
@@ -210,7 +207,6 @@ class SAR(Model, RecognitionModel):
     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -223,7 +219,7 @@ class SAR(Model, RecognitionModel):
         cfg: dictionary containing information about the model
     """
 
-    _children_names: List[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
+    _children_names: list[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
 
     def __init__(
         self,
@@ -236,7 +232,7 @@ class SAR(Model, RecognitionModel):
         num_decoder_cells: int = 2,
         dropout_prob: float = 0.0,
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -269,13 +265,11 @@ class SAR(Model, RecognitionModel):
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             gt: the encoded tensor with gt labels
             model_output: predicted logits of the model
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
@@ -296,11 +290,11 @@ class SAR(Model, RecognitionModel):
     def call(
         self,
         x: tf.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)
         # vertical max pooling --> (N, C, W)
         pooled_features = tf.reduce_max(features, axis=1)
@@ -318,7 +312,7 @@ class SAR(Model, RecognitionModel):
             self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
         )
 
-        out: Dict[str, tf.Tensor] = {}
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -340,14 +334,13 @@ class SARPostProcessor(RecognitionPostProcessor):
     """Post processor for SAR architectures
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         # N x L
@@ -371,7 +364,7 @@ def _sar(
     pretrained: bool,
     backbone_fn,
     pretrained_backbone: bool = True,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> SAR:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -414,12 +407,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the SAR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _sar("sar_resnet31", pretrained, resnet31, **kwargs)

doctr/models/recognition/utils.py

@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List
 
 from rapidfuzz.distance import Levenshtein
 
@@ -14,18 +13,16 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
     """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
 
     Args:
-    ----
         a: first char seq, suffix should be similar to b's prefix.
         b: second char seq, prefix should be similar to a's suffix.
         dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
             only used when the mother sequence is splitted on a character repetition
 
     Returns:
-    -------
         A merged character sequence.
 
     Example::
-        >>> from doctr.model.recognition.utils import merge_sequences
+        >>> from doctr.models.recognition.utils import merge_sequences
         >>> merge_sequences('abcd', 'cdefgh', 1.4)
         'abcdefgh'
         >>> merge_sequences('abcdi', 'cdefgh', 1.4)
@@ -61,26 +58,24 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
         return a[:-1] + b[index - 1 :]
 
 
-def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
+def merge_multi_strings(seq_list: list[str], dil_factor: float) -> str:
     """Recursively merges consecutive string sequences with overlapping characters.
 
     Args:
-    ----
         seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
         dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
             only used when the mother sequence is splitted on a character repetition
 
     Returns:
-    -------
         A merged character sequence
 
     Example::
-        >>> from doctr.model.recognition.utils import merge_multi_sequences
+        >>> from doctr.models.recognition.utils import merge_multi_sequences
        >>> merge_multi_sequences(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
         'abcdefghijkl'
     """
 
-    def _recursive_merge(a: str, seq_list: List[str], dil_factor: float) -> str:
+    def _recursive_merge(a: str, seq_list: list[str], dil_factor: float) -> str:
         # Recursive version of compute_overlap
         if len(seq_list) == 1:
             return merge_strings(a, seq_list[0], dil_factor)
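
One genuine documentation fix rides along with the typing churn: the doctest import path `doctr.model.recognition.utils` becomes `doctr.models.recognition.utils`. Note that the doctests still call `merge_sequences`/`merge_multi_sequences` even though the functions are defined as `merge_strings`/`merge_multi_strings`; usage with the actual names, reproducing the documented outputs:

    from doctr.models.recognition.utils import merge_multi_strings, merge_strings

    print(merge_strings("abcd", "cdefgh", 1.4))  # 'abcdefgh'
    print(merge_multi_strings(["abc", "bcdef", "difghi", "aijkl"], 1.4))  # 'abcdefghijkl'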

doctr/models/recognition/vitstr/__init__.py

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]

doctr/models/recognition/vitstr/base.py

@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Tuple
 
 import numpy as np
 
@@ -17,17 +16,15 @@ class _ViTSTR:
 
     def build_target(
         self,
-        gts: List[str],
-    ) -> Tuple[np.ndarray, List[int]]:
+        gts: list[str],
+    ) -> tuple[np.ndarray, list[int]]:
         """Encode a list of gts sequences into a np array and gives the corresponding*
         sequence lengths.
 
         Args:
-        ----
             gts: list of ground-truth labels
 
         Returns:
-        -------
             A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
         """
         encoded = encode_sequences(
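
`build_target` delegates to doctr's `encode_sequences` helper (its module, doctr/datasets/utils.py, is also touched in this release, +34 -29 above). A toy re-implementation of the idea only, not doctr's actual signature: map each ground-truth word to vocab indices, pad to a fixed length, and keep the true lengths for loss masking:

    import numpy as np

    def toy_build_target(gts: list[str], vocab: str, max_len: int) -> tuple[np.ndarray, list[int]]:
        pad = len(vocab)  # reserve one extra index for padding in this sketch
        encoded = np.full((len(gts), max_len), pad, dtype=np.int64)
        for i, word in enumerate(gts):
            for j, char in enumerate(word[:max_len]):
                encoded[i, j] = vocab.index(char)
        return encoded, [min(len(w), max_len) for w in gts]

    encoded, seq_len = toy_build_target(["cat", "dog"], "abcdefghijklmnopqrstuvwxyz", max_len=8)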
@@ -45,7 +42,6 @@ class _ViTSTRPostProcessor(RecognitionPostProcessor):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 

doctr/models/recognition/vitstr/pytorch.py

@@ -1,10 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any
 
 import torch
 from torch import nn
@@ -19,7 +20,7 @@ from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "vitstr_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -42,7 +43,6 @@ class ViTSTR(_ViTSTR, nn.Module):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -59,9 +59,9 @@ class ViTSTR(_ViTSTR, nn.Module):
         vocab: str,
         embedding_units: int,
         max_length: int = 32,  # different from paper
-        input_shape: Tuple[int, int, int] = (3, 32, 128),  # different from paper
+        input_shape: tuple[int, int, int] = (3, 32, 128),  # different from paper
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -77,10 +77,10 @@ class ViTSTR(_ViTSTR, nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]  # (batch_size, patches_seqlen, d_model)
 
         if target is not None:
@@ -98,7 +98,7 @@ class ViTSTR(_ViTSTR, nn.Module):
         logits = self.head(features).view(B, N, len(self.vocab) + 1)  # (batch_size, max_length, vocab + 1)
         decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -107,8 +107,13 @@ class ViTSTR(_ViTSTR, nn.Module):
         out["out_map"] = decoded_features
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
             # Post-process boxes
-            out["preds"] = self.postprocessor(decoded_features)
+            out["preds"] = _postprocess(decoded_features)
 
         if target is not None:
             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
@@ -125,19 +130,17 @@ class ViTSTR(_ViTSTR, nn.Module):
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of steps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1  # type: ignore[assignment]
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>.
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
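
The shift comment deserves unpacking: ViTSTR targets are encoded with a leading `<sos>` token, so the logits at step t are scored against `gt[:, 1:]`; without the shift the model would be rewarded for echoing the previous character. A toy illustration of the alignment (the index values are arbitrary, not doctr's actual vocab mapping):

    import torch

    sos, eos = 98, 99
    gt = torch.tensor([[sos, 3, 1, 2, eos]])  # encoded target: <sos> d b c <eos>
    targets = gt[:, 1:]  # [[3, 1, 2, 99]]: step t of the logits must predict the char after <sos>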
@@ -153,14 +156,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: torch.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
         preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
@@ -183,7 +185,7 @@ def _vitstr(
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> ViTSTR:
     # Patch the config
@@ -228,12 +230,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
@@ -259,12 +259,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
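
As with SAR, the public constructors are untouched; an inference-style call mirroring the `vitstr_small` docstring:

    import torch
    from doctr.models import vitstr_small

    model = vitstr_small(pretrained=False)
    input_tensor = torch.rand((1, 3, 32, 128))
    out = model(input_tensor, return_preds=True)
    # out["preds"] is a list of (word, confidence) pairs decoded by ViTSTRPostProcessor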