python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/predictor/_utils.py
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Tuple, Union
 
 import numpy as np
 
@@ -13,16 +12,15 @@ __all__ = ["split_crops", "remap_preds"]
 
 
 def split_crops(
-    crops: List[np.ndarray],
+    crops: list[np.ndarray],
     max_ratio: float,
     target_ratio: int,
     dilation: float,
     channels_last: bool = True,
-) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
+) -> tuple[list[np.ndarray], list[int | tuple[int, int]], bool]:
     """Chunk crops horizontally to match a given aspect ratio
 
     Args:
-    ----
         crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
         max_ratio: the maximum aspect ratio that won't trigger the chunk
         target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
@@ -30,12 +28,11 @@ def split_crops(
         channels_last: whether the numpy array has dimensions in channels last order
 
     Returns:
-    -------
         a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
     """
     _remap_required = False
-    crop_map: List[Union[int, Tuple[int, int]]] = []
-    new_crops: List[np.ndarray] = []
+    crop_map: list[int | tuple[int, int]] = []
+    new_crops: list[np.ndarray] = []
     for crop in crops:
         h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
         aspect_ratio = w / h
@@ -71,8 +68,8 @@
 
 
 def remap_preds(
-    preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
-) -> List[Tuple[str, float]]:
+    preds: list[tuple[str, float]], crop_map: list[int | tuple[int, int]], dilation: float
+) -> list[tuple[str, float]]:
     remapped_out = []
     for _idx in crop_map:
         # Crop hasn't been split
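Note: a minimal sketch of how split_crops and remap_preds fit together; the threshold values (8, 6, 1.4) and the dummy predictions are assumptions for illustration, not taken from this diff:

import numpy as np

from doctr.models.recognition.predictor._utils import remap_preds, split_crops

# One very wide crop: H=32, W=512 gives an aspect ratio of 16, above max_ratio
crops = [np.zeros((32, 512, 3), dtype=np.uint8)]
new_crops, crop_map, remapped = split_crops(crops, max_ratio=8, target_ratio=6, dilation=1.4)

# Hypothetical per-chunk predictions, as a recognition model would produce
preds = [("word", 0.9) for _ in new_crops]
if remapped:
    # Stitch the chunk predictions back into one (text, confidence) per original crop
    preds = remap_preds(preds, crop_map, dilation=1.4)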
doctr/models/recognition/predictor/pytorch.py
@@ -1,9 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any
 
 import numpy as np
 import torch
@@ -21,7 +22,6 @@ class RecognitionPredictor(nn.Module):
     """Implements an object able to identify character sequences in images
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core detection architecture
         split_wide_crops: wether to use crop splitting for high aspect ratio crops
@@ -44,9 +44,9 @@ class RecognitionPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        crops: Sequence[Union[np.ndarray, torch.Tensor]],
+        crops: Sequence[np.ndarray | torch.Tensor],
         **kwargs: Any,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         if len(crops) == 0:
             return []
         # Dimension check
@@ -67,7 +67,7 @@ class RecognitionPredictor(nn.Module):
             crops = new_crops
 
         # Resize & batch them
-        processed_batches = self.pre_processor(crops)
+        processed_batches = self.pre_processor(crops)  # type: ignore[arg-type]
 
         # Forward it
         _params = next(self.model.parameters())
doctr/models/recognition/predictor/tensorflow.py
@@ -1,9 +1,9 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Tuple, Union
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -21,13 +21,12 @@ class RecognitionPredictor(NestedObject):
     """Implements an object able to identify character sequences in images
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core detection architecture
         split_wide_crops: wether to use crop splitting for high aspect ratio crops
     """
 
-    _children_names: List[str] = ["pre_processor", "model"]
+    _children_names: list[str] = ["pre_processor", "model"]
 
     def __init__(
         self,
@@ -45,9 +44,9 @@ class RecognitionPredictor(NestedObject):
 
     def __call__(
         self,
-        crops: List[Union[np.ndarray, tf.Tensor]],
+        crops: list[np.ndarray | tf.Tensor],
         **kwargs: Any,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         if len(crops) == 0:
             return []
         # Dimension check
doctr/models/recognition/sar/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/recognition/sar/pytorch.py
@@ -1,10 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any
 
 import torch
 from torch import nn
@@ -19,7 +20,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "sar_resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -80,7 +81,6 @@ class SARDecoder(nn.Module):
     """Implements decoder module of the SAR model
 
     Args:
-    ----
         rnn_units: number of hidden units in recurrent cells
         max_length: maximum length of a sequence
         vocab_size: number of classes in the model alphabet
@@ -114,12 +114,12 @@
         self,
         features: torch.Tensor,  # (N, C, H, W)
         holistic: torch.Tensor,  # (N, C)
-        gt: Optional[torch.Tensor] = None,  # (N, L)
+        gt: torch.Tensor | None = None,  # (N, L)
     ) -> torch.Tensor:
         if gt is not None:
             gt_embedding = self.embed_tgt(gt)
 
-        logits_list: List[torch.Tensor] = []
+        logits_list: list[torch.Tensor] = []
 
         for t in range(self.max_length + 1):  # 32
             if t == 0:
@@ -166,7 +166,6 @@ class SAR(nn.Module, RecognitionModel):
     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -187,9 +186,9 @@
         attention_units: int = 512,
         max_length: int = 30,
         dropout_prob: float = 0.0,
-        input_shape: Tuple[int, int, int] = (3, 32, 128),
+        input_shape: tuple[int, int, int] = (3, 32, 128),
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -232,10 +231,10 @@
     def forward(
         self,
         x: torch.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]
         # NOTE: use max instead of functional max_pool2d which leads to ONNX incompatibility (kernel_size)
         # Vertical max pooling (N, C, H, W) --> (N, C, W)
@@ -254,7 +253,7 @@
 
         decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -263,8 +262,13 @@
             out["out_map"] = decoded_features
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
             # Post-process boxes
-            out["preds"] = self.postprocessor(decoded_features)
+            out["preds"] = _postprocess(decoded_features)
 
         if target is not None:
             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
@@ -281,19 +285,17 @@
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1  # type: ignore[assignment]
         # Compute loss
         # (N, L, vocab_size + 1)
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
@@ -308,14 +310,13 @@ class SARPostProcessor(RecognitionPostProcessor):
     """Post processor for SAR architectures
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: torch.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
         # N x L
@@ -338,7 +339,7 @@ def _sar(
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
     pretrained_backbone: bool = True,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> SAR:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -379,12 +380,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the SAR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _sar(
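Note: a minimal sketch of what the torch.compiler.disable change above enables. The postprocessor (tensor-to-string decoding) now runs eagerly, so the model can be wrapped with torch.compile without the compiler tracing through Python string handling; the input size follows the default input_shape in this diff:

import torch

from doctr.models import sar_resnet31

model = sar_resnet31(pretrained=False).eval()
compiled = torch.compile(model)  # string decoding stays eager via torch.compiler.disable

x = torch.rand((1, 3, 32, 128))
out = compiled(x, return_preds=True)
# out["preds"] is a list of (word, confidence) tuples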
doctr/models/recognition/sar/tensorflow.py
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, Sequential, layers
@@ -13,18 +13,18 @@ from doctr.datasets import VOCABS
 from doctr.utils.repr import NestedObject
 
 from ...classification import resnet31
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "sar_resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/sar_resnet31-c41e32a5.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0",
     },
 }
 
@@ -33,7 +33,6 @@ class SAREncoder(layers.Layer, NestedObject):
     """Implements encoder module of the SAR model
 
     Args:
-    ----
         rnn_units: number of hidden rnn units
         dropout_prob: dropout probability
     """
@@ -58,7 +57,6 @@ class AttentionModule(layers.Layer, NestedObject):
     """Implements attention module of the SAR model
 
     Args:
-    ----
         attention_units: number of hidden attention units
 
     """
@@ -120,7 +118,6 @@ class SARDecoder(layers.Layer, NestedObject):
     """Implements decoder module of the SAR model
 
     Args:
-    ----
         rnn_units: number of hidden units in recurrent cells
         max_length: maximum length of a sequence
         vocab_size: number of classes in the model alphabet
@@ -159,13 +156,13 @@
         self,
         features: tf.Tensor,
         holistic: tf.Tensor,
-        gt: Optional[tf.Tensor] = None,
+        gt: tf.Tensor | None = None,
         **kwargs: Any,
     ) -> tf.Tensor:
         if gt is not None:
             gt_embedding = self.embed_tgt(gt, **kwargs)
 
-        logits_list: List[tf.Tensor] = []
+        logits_list: list[tf.Tensor] = []
 
         for t in range(self.max_length + 1):  # 32
             if t == 0:
@@ -210,7 +207,6 @@ class SAR(Model, RecognitionModel):
     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -223,7 +219,7 @@
         cfg: dictionary containing information about the model
     """
 
-    _children_names: List[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
+    _children_names: list[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
 
     def __init__(
         self,
@@ -236,7 +232,7 @@
         num_decoder_cells: int = 2,
         dropout_prob: float = 0.0,
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -269,13 +265,11 @@
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             gt: the encoded tensor with gt labels
             model_output: predicted logits of the model
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
@@ -296,11 +290,11 @@
     def call(
         self,
         x: tf.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)
         # vertical max pooling --> (N, C, W)
         pooled_features = tf.reduce_max(features, axis=1)
@@ -318,7 +312,7 @@
             self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
         )
 
-        out: Dict[str, tf.Tensor] = {}
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -340,14 +334,13 @@ class SARPostProcessor(RecognitionPostProcessor):
     """Post processor for SAR architectures
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         # N x L
@@ -371,7 +364,7 @@ def _sar(
     pretrained: bool,
     backbone_fn,
     pretrained_backbone: bool = True,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> SAR:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -392,9 +385,13 @@
 
     # Build the model
     model = SAR(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
 
@@ -410,12 +407,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the SAR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _sar("sar_resnet31", pretrained, resnet31, **kwargs)
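Note: a minimal sketch of the fine-tuning path opened by the skip_mismatch argument above (TF backend; the vocab choice is illustrative). Passing a vocab that differs from the pretrained french one makes load_pretrained_params skip the mismatching classification layers instead of failing:

from doctr.datasets import VOCABS
from doctr.models import sar_resnet31

# Backbone weights load from the hub; the vocab-dependent head stays freshly initialized
model = sar_resnet31(pretrained=True, vocab=VOCABS["english"])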
doctr/models/recognition/utils.py
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List
 
 from rapidfuzz.distance import Levenshtein
 
@@ -14,18 +13,16 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
     """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
 
     Args:
-    ----
         a: first char seq, suffix should be similar to b's prefix.
         b: second char seq, prefix should be similar to a's suffix.
         dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
             only used when the mother sequence is splitted on a character repetition
 
     Returns:
-    -------
         A merged character sequence.
 
     Example::
-        >>> from doctr.model.recognition.utils import merge_sequences
+        >>> from doctr.models.recognition.utils import merge_sequences
         >>> merge_sequences('abcd', 'cdefgh', 1.4)
         'abcdefgh'
         >>> merge_sequences('abcdi', 'cdefgh', 1.4)
@@ -61,26 +58,24 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
     return a[:-1] + b[index - 1 :]
 
 
-def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
+def merge_multi_strings(seq_list: list[str], dil_factor: float) -> str:
     """Recursively merges consecutive string sequences with overlapping characters.
 
     Args:
-    ----
         seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
         dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
            only used when the mother sequence is splitted on a character repetition
 
     Returns:
-    -------
         A merged character sequence
 
     Example::
-        >>> from doctr.model.recognition.utils import merge_multi_sequences
+        >>> from doctr.models.recognition.utils import merge_multi_sequences
         >>> merge_multi_sequences(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
         'abcdefghijkl'
     """
 
-    def _recursive_merge(a: str, seq_list: List[str], dil_factor: float) -> str:
+    def _recursive_merge(a: str, seq_list: list[str], dil_factor: float) -> str:
         # Recursive version of compute_overlap
         if len(seq_list) == 1:
             return merge_strings(a, seq_list[0], dil_factor)
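Note: the doctest import path was corrected above (doctr.model -> doctr.models), but the examples still call merge_sequences / merge_multi_sequences, while the module exports merge_strings / merge_multi_strings. A short sketch with the exported names, reusing the expected values from the docstring examples:

from doctr.models.recognition.utils import merge_multi_strings, merge_strings

print(merge_strings("abcd", "cdefgh", 1.4))  # 'abcdefgh'
print(merge_multi_strings(["abc", "bcdef", "difghi", "aijkl"], 1.4))  # 'abcdefghijkl'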
doctr/models/recognition/vitstr/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/recognition/vitstr/base.py
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Tuple
 
 import numpy as np
 
@@ -17,17 +16,15 @@ class _ViTSTR:
 
     def build_target(
         self,
-        gts: List[str],
-    ) -> Tuple[np.ndarray, List[int]]:
+        gts: list[str],
+    ) -> tuple[np.ndarray, list[int]]:
         """Encode a list of gts sequences into a np array and gives the corresponding*
         sequence lengths.
 
         Args:
-        ----
             gts: list of ground-truth labels
 
         Returns:
-        -------
             A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
         """
         encoded = encode_sequences(
@@ -45,7 +42,6 @@ class _ViTSTRPostProcessor(RecognitionPostProcessor):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
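Note: a rough sketch of build_target in use (the vitstr_small entry point and the inputs are illustrative assumptions; the return layout follows the docstring above):

import numpy as np

from doctr.models import vitstr_small

model = vitstr_small(pretrained=False)
encoded, seq_len = model.build_target(["cat", "dog"])
# encoded: np.ndarray of encoded label indices, padded up to the model's max length
# seq_len: per-word sequence lengths for each entry of the batch, here [3, 3]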