python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/predictor/_utils.py

@@ -1,9 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Tuple, Union
+
+import math
 
 import numpy as np
 
@@ -13,74 +14,132 @@ __all__ = ["split_crops", "remap_preds"]
 
 
 def split_crops(
-    crops: List[np.ndarray],
+    crops: list[np.ndarray],
     max_ratio: float,
     target_ratio: int,
-    dilation: float,
+    split_overlap_ratio: float,
     channels_last: bool = True,
-) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
-    """Chunk crops horizontally to match a given aspect ratio
+) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
+    """
+    Split crops horizontally if they exceed a given aspect ratio.
 
     Args:
-    ----
-        crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
-        max_ratio: the maximum aspect ratio that won't trigger the chunk
-        target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
-        dilation: the width dilation of final chunks (to provide some overlaps)
-        channels_last: whether the numpy array has dimensions in channels last order
+        crops: List of image crops (H, W, C) if channels_last else (C, H, W).
+        max_ratio: Aspect ratio threshold above which crops are split.
+        target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
+        split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
+        channels_last: Whether the crops are in channels-last format.
 
     Returns:
-    -------
-        a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
+        A tuple containing:
+        - The new list of crops (possibly with splits),
+        - A mapping indicating how to reassemble predictions,
+        - A boolean indicating whether remapping is required.
     """
-    _remap_required = False
-    crop_map: List[Union[int, Tuple[int, int]]] = []
-    new_crops: List[np.ndarray] = []
+    if split_overlap_ratio <= 0.0 or split_overlap_ratio >= 1.0:
+        raise ValueError(f"Valid range for split_overlap_ratio is (0.0, 1.0), but is: {split_overlap_ratio}")
+
+    remap_required = False
+    new_crops: list[np.ndarray] = []
+    crop_map: list[int | tuple[int, int, float]] = []
+
     for crop in crops:
         h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
         aspect_ratio = w / h
+
         if aspect_ratio > max_ratio:
-            # Determine the number of crops, reference aspect ratio = 4 = 128 / 32
-            num_subcrops = int(aspect_ratio // target_ratio)
-            # Find the new widths, additional dilation factor to overlap crops
-            width = dilation * w / num_subcrops
-            centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
-            # Get the crops
-            if channels_last:
-                _crops = [
-                    crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :]
-                    for center in centers
-                ]
+            split_width = max(1, math.ceil(h * target_ratio))
+            overlap_width = max(0, math.floor(split_width * split_overlap_ratio))
+
+            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width, channels_last)
+
+            # Remove any empty splits
+            splits = [s for s in splits if all(dim > 0 for dim in s.shape)]
+            if splits:
+                crop_map.append((len(new_crops), len(new_crops) + len(splits), last_overlap))
+                new_crops.extend(splits)
+                remap_required = True
             else:
-                _crops = [
-                    crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2)))]
-                    for center in centers
-                ]
-            # Avoid sending zero-sized crops
-            _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
-            # Record the slice of crops
-            crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
-            new_crops.extend(_crops)
-            # At least one crop will require merging
-            _remap_required = True
+                # Fallback: treat it as a single crop
+                crop_map.append(len(new_crops))
+                new_crops.append(crop)
         else:
             crop_map.append(len(new_crops))
             new_crops.append(crop)
 
-    return new_crops, crop_map, _remap_required
+    return new_crops, crop_map, remap_required
+
+
+def _split_horizontally(
+    image: np.ndarray, split_width: int, overlap_width: int, channels_last: bool
+) -> tuple[list[np.ndarray], float]:
+    """
+    Horizontally split a single image with overlapping regions.
+
+    Args:
+        image: The image to split (H, W, C) if channels_last else (C, H, W).
+        split_width: Width of each split.
+        overlap_width: Width of the overlapping region.
+        channels_last: Whether the image is in channels-last format.
+
+    Returns:
+        - A list of horizontal image slices.
+        - The actual overlap ratio of the last split.
+    """
+    image_width = image.shape[1] if channels_last else image.shape[-1]
+    if image_width <= split_width:
+        return [image], 0.0
+
+    # Compute start columns for each split
+    step = split_width - overlap_width
+    starts = list(range(0, image_width - split_width + 1, step))
+
+    # Ensure the last patch reaches the end of the image
+    if starts[-1] + split_width < image_width:
+        starts.append(image_width - split_width)
+
+    splits = []
+    for start_col in starts:
+        end_col = start_col + split_width
+        if channels_last:
+            split = image[:, start_col:end_col, :]
+        else:
+            split = image[:, :, start_col:end_col]
+        splits.append(split)
+
+    # Calculate the last overlap ratio, if only one split no overlap
+    last_overlap = 0
+    if len(starts) > 1:
+        last_overlap = (starts[-2] + split_width) - starts[-1]
+    last_overlap_ratio = last_overlap / split_width if split_width else 0.0
+
+    return splits, last_overlap_ratio
 
 
 def remap_preds(
-    preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
-) -> List[Tuple[str, float]]:
-    remapped_out = []
-    for _idx in crop_map:
-        # Crop hasn't been split
-        if isinstance(_idx, int):
-            remapped_out.append(preds[_idx])
+    preds: list[tuple[str, float]],
+    crop_map: list[int | tuple[int, int, float]],
+    overlap_ratio: float,
+) -> list[tuple[str, float]]:
+    """
+    Reconstruct predictions from possibly split crops.
+
+    Args:
+        preds: List of (text, confidence) tuples from each crop.
+        crop_map: Map returned by `split_crops`.
+        overlap_ratio: Overlap ratio used during splitting.
+
+    Returns:
+        List of merged (text, confidence) tuples corresponding to original crops.
+    """
+    remapped = []
+    for item in crop_map:
+        if isinstance(item, int):
+            remapped.append(preds[item])
         else:
-            # unzip
-            vals, probs = zip(*preds[_idx[0] : _idx[1]])
-            # Merge the string values
-            remapped_out.append((merge_multi_strings(vals, dilation), min(probs)))  # type: ignore[arg-type]
-    return remapped_out
+            start_idx, end_idx, last_overlap = item
+            text_parts, confidences = zip(*preds[start_idx:end_idx])
+            merged_text = merge_multi_strings(list(text_parts), overlap_ratio, last_overlap)
+            merged_conf = sum(confidences) / len(confidences)  # average confidence
+            remapped.append((merged_text, merged_conf))
+    return remapped
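
For orientation, here is a small sketch (not part of the diff) of what the rewritten split_crops now returns; the 32x640 input crop is a hypothetical example and the values below follow from the code above.

import numpy as np

from doctr.models.recognition.predictor._utils import split_crops

# A 32x640 crop has aspect ratio 20, well above a max_ratio of 8
crop = np.zeros((32, 640, 3), dtype=np.uint8)
new_crops, crop_map, remapped = split_crops(
    [crop], max_ratio=8, target_ratio=6, split_overlap_ratio=0.5, channels_last=True
)
# split_width = ceil(32 * 6) = 192, overlap_width = 96, step = 96
# start columns: 0, 96, 192, 288, 384, plus a final 448 so the last slice reaches the right edge
print(len(new_crops))  # 6 slices of shape (32, 192, 3)
print(crop_map)        # [(0, 6, 0.666...)] -> the last slice overlaps the previous one by 128/192
print(remapped)        # True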

doctr/models/recognition/predictor/pytorch.py

@@ -1,9 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any
 
 import numpy as np
 import torch
@@ -21,7 +22,6 @@ class RecognitionPredictor(nn.Module):
     """Implements an object able to identify character sequences in images
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core detection architecture
         split_wide_crops: wether to use crop splitting for high aspect ratio crops
@@ -38,15 +38,15 @@ class RecognitionPredictor(nn.Module):
         self.model = model.eval()
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.dil_factor = 1.4  # Dilation factor to overlap the crops
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio
 
     @torch.inference_mode()
     def forward(
         self,
-        crops: Sequence[Union[np.ndarray, torch.Tensor]],
+        crops: Sequence[np.ndarray | torch.Tensor],
         **kwargs: Any,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         if len(crops) == 0:
             return []
         # Dimension check
@@ -60,14 +60,14 @@ class RecognitionPredictor(nn.Module):
                 crops,  # type: ignore[arg-type]
                 self.critical_ar,
                 self.target_ar,
-                self.dil_factor,
+                self.overlap_ratio,
                 isinstance(crops[0], np.ndarray),
             )
             if remapped:
                 crops = new_crops
 
         # Resize & batch them
-        processed_batches = self.pre_processor(crops)
+        processed_batches = self.pre_processor(crops)  # type: ignore[arg-type]
 
         # Forward it
         _params = next(self.model.parameters())
@@ -81,6 +81,6 @@
 
         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.dil_factor)
+            out = remap_preds(out, crop_map, self.overlap_ratio)
 
         return out
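
The split/remap change is internal to the recognition predictor: callers keep the same API, but wide crops are now split with a fixed 50% overlap instead of the former 1.4 dilation factor. A hedged usage sketch (the model name and blank input are illustrative only):

import numpy as np

from doctr.models import recognition_predictor

predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True)
wide_crop = np.zeros((32, 640, 3), dtype=np.uint8)  # aspect ratio 20 > critical_ar (8)
out = predictor([wide_crop])  # split with 50% overlap, predictions merged back
print(out)  # one (text, confidence) pair per original crop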

doctr/models/recognition/predictor/tensorflow.py

@@ -1,9 +1,9 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Tuple, Union
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -21,13 +21,12 @@ class RecognitionPredictor(NestedObject):
     """Implements an object able to identify character sequences in images
 
     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core detection architecture
         split_wide_crops: wether to use crop splitting for high aspect ratio crops
     """
 
-    _children_names: List[str] = ["pre_processor", "model"]
+    _children_names: list[str] = ["pre_processor", "model"]
 
     def __init__(
         self,
@@ -40,14 +39,14 @@ class RecognitionPredictor(NestedObject):
         self.model = model
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.dil_factor = 1.4  # Dilation factor to overlap the crops
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio
 
     def __call__(
         self,
-        crops: List[Union[np.ndarray, tf.Tensor]],
+        crops: list[np.ndarray | tf.Tensor],
         **kwargs: Any,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         if len(crops) == 0:
             return []
         # Dimension check
@@ -57,7 +56,7 @@ class RecognitionPredictor(NestedObject):
         # Split crops that are too wide
         remapped = False
         if self.split_wide_crops:
-            new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.dil_factor)
+            new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.overlap_ratio)
             if remapped:
                 crops = new_crops
 
@@ -75,6 +74,6 @@
 
         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.dil_factor)
+            out = remap_preds(out, crop_map, self.overlap_ratio)
 
         return out

doctr/models/recognition/sar/__init__.py

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]

doctr/models/recognition/sar/pytorch.py

@@ -1,10 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any
 
 import torch
 from torch import nn
@@ -19,7 +20,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "sar_resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -80,7 +81,6 @@ class SARDecoder(nn.Module):
     """Implements decoder module of the SAR model
 
     Args:
-    ----
         rnn_units: number of hidden units in recurrent cells
         max_length: maximum length of a sequence
         vocab_size: number of classes in the model alphabet
@@ -114,12 +114,12 @@ class SARDecoder(nn.Module):
         self,
         features: torch.Tensor,  # (N, C, H, W)
         holistic: torch.Tensor,  # (N, C)
-        gt: Optional[torch.Tensor] = None,  # (N, L)
+        gt: torch.Tensor | None = None,  # (N, L)
     ) -> torch.Tensor:
         if gt is not None:
             gt_embedding = self.embed_tgt(gt)
 
-        logits_list: List[torch.Tensor] = []
+        logits_list: list[torch.Tensor] = []
 
         for t in range(self.max_length + 1):  # 32
             if t == 0:
@@ -166,7 +166,6 @@ class SAR(nn.Module, RecognitionModel):
     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -187,9 +186,9 @@ class SAR(nn.Module, RecognitionModel):
         attention_units: int = 512,
         max_length: int = 30,
         dropout_prob: float = 0.0,
-        input_shape: Tuple[int, int, int] = (3, 32, 128),
+        input_shape: tuple[int, int, int] = (3, 32, 128),
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -229,13 +228,22 @@ class SAR(nn.Module, RecognitionModel):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]
         # NOTE: use max instead of functional max_pool2d which leads to ONNX incompatibility (kernel_size)
         # Vertical max pooling (N, C, H, W) --> (N, C, W)
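
A brief sketch of the new from_pretrained helper added above; the checkpoint path is hypothetical and the heavy lifting is delegated to doctr.models.utils.load_pretrained_params, as the diff shows:

from doctr.models import sar_resnet31

model = sar_resnet31(pretrained=False)
model.from_pretrained("path/to/sar_resnet31_checkpoint.pt")  # local path or URL (illustrative)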
@@ -254,7 +262,7 @@ class SAR(nn.Module, RecognitionModel):
 
         decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -263,8 +271,13 @@ class SAR(nn.Module, RecognitionModel):
             out["out_map"] = decoded_features
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
             # Post-process boxes
-            out["preds"] = self.postprocessor(decoded_features)
+            out["preds"] = _postprocess(decoded_features)
 
         if target is not None:
             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
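
The torch.compiler.disable wrapper above keeps the string post-processing out of the compiled graph so the rest of the forward pass stays compilable. A minimal, self-contained sketch of the same pattern (toy functions, not doctr code):

import torch


@torch.compiler.disable  # executed eagerly, skipped by the compiler
def postprocess(logits: torch.Tensor) -> list[int]:
    return logits.argmax(-1).tolist()


@torch.compile
def forward(x: torch.Tensor) -> list[int]:
    logits = x * 2.0  # stands in for the model's decoding step
    return postprocess(logits)


print(forward(torch.randn(2, 4)))  # tensor math is compiled, postprocess runs outside the graph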
@@ -281,19 +294,17 @@ class SAR(nn.Module, RecognitionModel):
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1  # type: ignore[assignment]
         # Compute loss
         # (N, L, vocab_size + 1)
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
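
For context, a hedged sketch of the masking this docstring describes (per-timestep cross-entropy, positions after the shifted sequence length ignored); it mirrors the idea, not necessarily the package's exact implementation:

import torch
import torch.nn.functional as F

N, L, V = 2, 5, 10                          # batch, timesteps, vocab size (incl. <eos>)
logits = torch.randn(N, L, V)
gt = torch.randint(0, V, (N, L))
seq_len = torch.tensor([3, 4]) + 1          # +1 for the additional <eos> token

cce = F.cross_entropy(logits.permute(0, 2, 1), gt, reduction="none")  # (N, L)
mask = torch.arange(L)[None, :] < seq_len[:, None]                    # keep timesteps up to <eos>
loss = ((cce * mask).sum(dim=1) / seq_len).mean()
print(loss)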
@@ -308,14 +319,13 @@ class SARPostProcessor(RecognitionPostProcessor):
     """Post processor for SAR architectures
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: torch.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
         # N x L
@@ -338,7 +348,7 @@ def _sar(
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
     pretrained_backbone: bool = True,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> SAR:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -363,7 +373,7 @@ def _sar(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -379,12 +389,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the SAR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _sar(