python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff shows the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the registry.
Files changed (170)
  1. doctr/__init__.py +0 -1
  2. doctr/contrib/artefacts.py +1 -1
  3. doctr/contrib/base.py +1 -1
  4. doctr/datasets/__init__.py +0 -5
  5. doctr/datasets/coco_text.py +1 -1
  6. doctr/datasets/cord.py +1 -1
  7. doctr/datasets/datasets/__init__.py +1 -6
  8. doctr/datasets/datasets/base.py +1 -1
  9. doctr/datasets/datasets/pytorch.py +3 -3
  10. doctr/datasets/detection.py +1 -1
  11. doctr/datasets/doc_artefacts.py +1 -1
  12. doctr/datasets/funsd.py +1 -1
  13. doctr/datasets/generator/__init__.py +1 -6
  14. doctr/datasets/generator/base.py +1 -1
  15. doctr/datasets/generator/pytorch.py +1 -1
  16. doctr/datasets/ic03.py +1 -1
  17. doctr/datasets/ic13.py +1 -1
  18. doctr/datasets/iiit5k.py +1 -1
  19. doctr/datasets/iiithws.py +1 -1
  20. doctr/datasets/imgur5k.py +1 -1
  21. doctr/datasets/mjsynth.py +1 -1
  22. doctr/datasets/ocr.py +1 -1
  23. doctr/datasets/orientation.py +1 -1
  24. doctr/datasets/recognition.py +1 -1
  25. doctr/datasets/sroie.py +1 -1
  26. doctr/datasets/svhn.py +1 -1
  27. doctr/datasets/svt.py +1 -1
  28. doctr/datasets/synthtext.py +1 -1
  29. doctr/datasets/utils.py +1 -1
  30. doctr/datasets/vocabs.py +1 -3
  31. doctr/datasets/wildreceipt.py +1 -1
  32. doctr/file_utils.py +3 -102
  33. doctr/io/elements.py +1 -1
  34. doctr/io/html.py +1 -1
  35. doctr/io/image/__init__.py +1 -7
  36. doctr/io/image/base.py +1 -1
  37. doctr/io/image/pytorch.py +2 -2
  38. doctr/io/pdf.py +1 -1
  39. doctr/io/reader.py +1 -1
  40. doctr/models/_utils.py +56 -18
  41. doctr/models/builder.py +1 -1
  42. doctr/models/classification/magc_resnet/__init__.py +1 -6
  43. doctr/models/classification/magc_resnet/pytorch.py +3 -3
  44. doctr/models/classification/mobilenet/__init__.py +1 -6
  45. doctr/models/classification/mobilenet/pytorch.py +1 -1
  46. doctr/models/classification/predictor/__init__.py +1 -6
  47. doctr/models/classification/predictor/pytorch.py +2 -2
  48. doctr/models/classification/resnet/__init__.py +1 -6
  49. doctr/models/classification/resnet/pytorch.py +1 -1
  50. doctr/models/classification/textnet/__init__.py +1 -6
  51. doctr/models/classification/textnet/pytorch.py +2 -2
  52. doctr/models/classification/vgg/__init__.py +1 -6
  53. doctr/models/classification/vgg/pytorch.py +1 -1
  54. doctr/models/classification/vip/__init__.py +1 -4
  55. doctr/models/classification/vip/layers/__init__.py +1 -4
  56. doctr/models/classification/vip/layers/pytorch.py +2 -2
  57. doctr/models/classification/vip/pytorch.py +1 -1
  58. doctr/models/classification/vit/__init__.py +1 -6
  59. doctr/models/classification/vit/pytorch.py +3 -3
  60. doctr/models/classification/zoo.py +7 -12
  61. doctr/models/core.py +1 -1
  62. doctr/models/detection/_utils/__init__.py +1 -6
  63. doctr/models/detection/_utils/base.py +1 -1
  64. doctr/models/detection/_utils/pytorch.py +1 -1
  65. doctr/models/detection/core.py +2 -2
  66. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  67. doctr/models/detection/differentiable_binarization/base.py +5 -13
  68. doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
  69. doctr/models/detection/fast/__init__.py +1 -6
  70. doctr/models/detection/fast/base.py +5 -15
  71. doctr/models/detection/fast/pytorch.py +5 -5
  72. doctr/models/detection/linknet/__init__.py +1 -6
  73. doctr/models/detection/linknet/base.py +4 -13
  74. doctr/models/detection/linknet/pytorch.py +3 -3
  75. doctr/models/detection/predictor/__init__.py +1 -6
  76. doctr/models/detection/predictor/pytorch.py +2 -2
  77. doctr/models/detection/zoo.py +16 -33
  78. doctr/models/factory/hub.py +26 -34
  79. doctr/models/kie_predictor/__init__.py +1 -6
  80. doctr/models/kie_predictor/base.py +1 -1
  81. doctr/models/kie_predictor/pytorch.py +3 -7
  82. doctr/models/modules/layers/__init__.py +1 -6
  83. doctr/models/modules/layers/pytorch.py +4 -4
  84. doctr/models/modules/transformer/__init__.py +1 -6
  85. doctr/models/modules/transformer/pytorch.py +3 -3
  86. doctr/models/modules/vision_transformer/__init__.py +1 -6
  87. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  88. doctr/models/predictor/__init__.py +1 -6
  89. doctr/models/predictor/base.py +4 -9
  90. doctr/models/predictor/pytorch.py +3 -6
  91. doctr/models/preprocessor/__init__.py +1 -6
  92. doctr/models/preprocessor/pytorch.py +28 -33
  93. doctr/models/recognition/core.py +1 -1
  94. doctr/models/recognition/crnn/__init__.py +1 -6
  95. doctr/models/recognition/crnn/pytorch.py +7 -7
  96. doctr/models/recognition/master/__init__.py +1 -6
  97. doctr/models/recognition/master/base.py +1 -1
  98. doctr/models/recognition/master/pytorch.py +6 -6
  99. doctr/models/recognition/parseq/__init__.py +1 -6
  100. doctr/models/recognition/parseq/base.py +1 -1
  101. doctr/models/recognition/parseq/pytorch.py +6 -6
  102. doctr/models/recognition/predictor/__init__.py +1 -6
  103. doctr/models/recognition/predictor/_utils.py +8 -17
  104. doctr/models/recognition/predictor/pytorch.py +2 -3
  105. doctr/models/recognition/sar/__init__.py +1 -6
  106. doctr/models/recognition/sar/pytorch.py +4 -4
  107. doctr/models/recognition/utils.py +1 -1
  108. doctr/models/recognition/viptr/__init__.py +1 -4
  109. doctr/models/recognition/viptr/pytorch.py +4 -4
  110. doctr/models/recognition/vitstr/__init__.py +1 -6
  111. doctr/models/recognition/vitstr/base.py +1 -1
  112. doctr/models/recognition/vitstr/pytorch.py +4 -4
  113. doctr/models/recognition/zoo.py +14 -14
  114. doctr/models/utils/__init__.py +1 -6
  115. doctr/models/utils/pytorch.py +3 -2
  116. doctr/models/zoo.py +1 -1
  117. doctr/transforms/functional/__init__.py +1 -6
  118. doctr/transforms/functional/base.py +3 -2
  119. doctr/transforms/functional/pytorch.py +5 -5
  120. doctr/transforms/modules/__init__.py +1 -7
  121. doctr/transforms/modules/base.py +28 -94
  122. doctr/transforms/modules/pytorch.py +29 -27
  123. doctr/utils/common_types.py +1 -1
  124. doctr/utils/data.py +1 -2
  125. doctr/utils/fonts.py +1 -1
  126. doctr/utils/geometry.py +7 -11
  127. doctr/utils/metrics.py +1 -1
  128. doctr/utils/multithreading.py +1 -1
  129. doctr/utils/reconstitution.py +1 -1
  130. doctr/utils/repr.py +1 -1
  131. doctr/utils/visualization.py +2 -2
  132. doctr/version.py +1 -1
  133. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
  134. python_doctr-1.0.1.dist-info/RECORD +149 -0
  135. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
  136. doctr/datasets/datasets/tensorflow.py +0 -59
  137. doctr/datasets/generator/tensorflow.py +0 -58
  138. doctr/datasets/loader.py +0 -94
  139. doctr/io/image/tensorflow.py +0 -101
  140. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  141. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  142. doctr/models/classification/predictor/tensorflow.py +0 -60
  143. doctr/models/classification/resnet/tensorflow.py +0 -418
  144. doctr/models/classification/textnet/tensorflow.py +0 -275
  145. doctr/models/classification/vgg/tensorflow.py +0 -125
  146. doctr/models/classification/vit/tensorflow.py +0 -201
  147. doctr/models/detection/_utils/tensorflow.py +0 -34
  148. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  149. doctr/models/detection/fast/tensorflow.py +0 -427
  150. doctr/models/detection/linknet/tensorflow.py +0 -377
  151. doctr/models/detection/predictor/tensorflow.py +0 -70
  152. doctr/models/kie_predictor/tensorflow.py +0 -187
  153. doctr/models/modules/layers/tensorflow.py +0 -171
  154. doctr/models/modules/transformer/tensorflow.py +0 -235
  155. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  156. doctr/models/predictor/tensorflow.py +0 -155
  157. doctr/models/preprocessor/tensorflow.py +0 -122
  158. doctr/models/recognition/crnn/tensorflow.py +0 -317
  159. doctr/models/recognition/master/tensorflow.py +0 -320
  160. doctr/models/recognition/parseq/tensorflow.py +0 -516
  161. doctr/models/recognition/predictor/tensorflow.py +0 -79
  162. doctr/models/recognition/sar/tensorflow.py +0 -423
  163. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  164. doctr/models/utils/tensorflow.py +0 -189
  165. doctr/transforms/functional/tensorflow.py +0 -254
  166. doctr/transforms/modules/tensorflow.py +0 -562
  167. python_doctr-0.12.0.dist-info/RECORD +0 -180
  168. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
  169. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
  170. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
doctr/models/recognition/parseq/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -19,7 +19,7 @@ from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
 
 from ...classification import vit_s
-from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
@@ -299,7 +299,7 @@ class PARSeq(_PARSeq, nn.Module):
 
             # Stop decoding if all sequences have reached the EOS token
             # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():  # type: ignore[attr-defined]
+            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
                 break
 
         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)
@@ -314,7 +314,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         # Create padding mask for refined target input maskes all behind EOS token as False
         # (N, 1, 1, max_length)
-        target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)  # type: ignore[attr-defined]
+        target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
         mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
         logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
@@ -367,7 +367,7 @@ class PARSeq(_PARSeq, nn.Module):
             # remove the [EOS] tokens for the succeeding perms
             if i == 1:
                 gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
-                n = (gt_out != self.vocab_size + 2).sum().item()  # type: ignore[attr-defined]
+                n = (gt_out != self.vocab_size + 2).sum().item()
 
         loss /= loss_numel
 
@@ -391,7 +391,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)
 
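
Several hunks in this release drop `# type: ignore[attr-defined]` from `@torch.compiler.disable`, which ships with proper type hints in current PyTorch releases. For readers unfamiliar with the pattern being diffed here, a minimal runnable sketch of why doctr carves postprocessing out of compilation (the `TinyModel` module is hypothetical, not doctr's class):

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(8, 4)

        def forward(self, x: torch.Tensor) -> list[int]:
            logits = self.linear(x)

            # Python-side decoding (argmax -> list) breaks graph capture, so it
            # is excluded from the compiled region, mirroring the _postprocess
            # closures in the hunks above.
            @torch.compiler.disable
            def _postprocess(out: torch.Tensor) -> list[int]:
                return out.argmax(-1).tolist()

            return _postprocess(logits)

    compiled = torch.compile(TinyModel())
    print(compiled(torch.rand(2, 8)))  # e.g. [3, 1]
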
doctr/models/recognition/predictor/__init__.py CHANGED
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
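
This pattern repeats through the rest of the diff: with the TensorFlow backend removed in 1.0, every `__init__.py` that used to dispatch on the installed framework now re-exports the PyTorch implementation unconditionally. A quick post-upgrade sanity check (assuming only that the public module paths are unchanged):

    # doctr >= 1.0 no longer ships a tensorflow.py sibling module, so this
    # import resolves straight to the PyTorch implementation.
    from doctr.models.recognition.predictor import RecognitionPredictor

    print(RecognitionPredictor.__module__)
    # expected: doctr.models.recognition.predictor.pytorch
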
doctr/models/recognition/predictor/_utils.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -18,17 +18,15 @@ def split_crops(
     max_ratio: float,
     target_ratio: int,
     split_overlap_ratio: float,
-    channels_last: bool = True,
 ) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
     """
     Split crops horizontally if they exceed a given aspect ratio.
 
     Args:
-        crops: List of image crops (H, W, C) if channels_last else (C, H, W).
+        crops: List of image crops (H, W, C).
         max_ratio: Aspect ratio threshold above which crops are split.
         target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
         split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
-        channels_last: Whether the crops are in channels-last format.
 
     Returns:
         A tuple containing:
@@ -44,14 +42,14 @@ def split_crops(
     crop_map: list[int | tuple[int, int, float]] = []
 
     for crop in crops:
-        h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
+        h, w = crop.shape[:2]
         aspect_ratio = w / h
 
         if aspect_ratio > max_ratio:
             split_width = max(1, math.ceil(h * target_ratio))
             overlap_width = max(0, math.floor(split_width * split_overlap_ratio))
 
-            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width, channels_last)
+            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width)
 
             # Remove any empty splits
             splits = [s for s in splits if all(dim > 0 for dim in s.shape)]
@@ -70,23 +68,20 @@ def split_crops(
     return new_crops, crop_map, remap_required
 
 
-def _split_horizontally(
-    image: np.ndarray, split_width: int, overlap_width: int, channels_last: bool
-) -> tuple[list[np.ndarray], float]:
+def _split_horizontally(image: np.ndarray, split_width: int, overlap_width: int) -> tuple[list[np.ndarray], float]:
     """
     Horizontally split a single image with overlapping regions.
 
     Args:
-        image: The image to split (H, W, C) if channels_last else (C, H, W).
+        image: The image to split (H, W, C).
         split_width: Width of each split.
        overlap_width: Width of the overlapping region.
-        channels_last: Whether the image is in channels-last format.
 
     Returns:
        - A list of horizontal image slices.
        - The actual overlap ratio of the last split.
    """
-    image_width = image.shape[1] if channels_last else image.shape[-1]
+    image_width = image.shape[1]
     if image_width <= split_width:
         return [image], 0.0
 
@@ -101,11 +96,7 @@ def _split_horizontally(
     splits = []
     for start_col in starts:
         end_col = start_col + split_width
-        if channels_last:
-            split = image[:, start_col:end_col, :]
-        else:
-            split = image[:, :, start_col:end_col]
-        splits.append(split)
+        splits.append(image[:, start_col:end_col, :])
 
     # Calculate the last overlap ratio, if only one split no overlap
     last_overlap = 0
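
Since crops are now always channels-last (H, W, C) NumPy arrays, the `channels_last` flag became dead weight and is gone from both functions. A hedged sketch of the simplified contract (import path as in this diff; the argument values are illustrative):

    import numpy as np

    from doctr.models.recognition.predictor._utils import split_crops

    # One very wide crop (H=32, W=512): aspect ratio 16 exceeds max_ratio=8,
    # so it is cut into overlapping slices about target_ratio * H wide.
    wide_crop = np.zeros((32, 512, 3), dtype=np.uint8)
    new_crops, crop_map, remapped = split_crops(
        [wide_crop], max_ratio=8, target_ratio=6, split_overlap_ratio=0.5
    )
    print(len(new_crops), remapped)  # several (H, W, C) slices, remapped == True
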
doctr/models/recognition/predictor/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -44,7 +44,7 @@ class RecognitionPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        crops: Sequence[np.ndarray | torch.Tensor],
+        crops: Sequence[np.ndarray],
         **kwargs: Any,
     ) -> list[tuple[str, float]]:
         if len(crops) == 0:
@@ -61,7 +61,6 @@ class RecognitionPredictor(nn.Module):
             self.critical_ar,
             self.target_ar,
             self.overlap_ratio,
-            isinstance(crops[0], np.ndarray),
         )
         if remapped:
             crops = new_crops
doctr/models/recognition/sar/__init__.py CHANGED
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
doctr/models/recognition/sar/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS
 
 from ...classification import resnet31
-from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
@@ -272,7 +272,7 @@ class SAR(nn.Module, RecognitionModel):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(decoded_features)
 
@@ -304,7 +304,7 @@ class SAR(nn.Module, RecognitionModel):
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token
-        seq_len = seq_len + 1  # type: ignore[assignment]
+        seq_len = seq_len + 1
         # Compute loss
         # (N, L, vocab_size + 1)
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
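
The loss hunk relies on a detail worth spelling out: `F.cross_entropy` expects class scores on dimension 1, so the (N, L, vocab_size + 1) logits are permuted to (N, vocab_size + 1, L) before computing a per-timestep loss. A standalone sketch (shapes illustrative):

    import torch
    import torch.nn.functional as F

    N, L, V = 2, 10, 27  # batch, sequence length, vocab size + 1
    model_output = torch.randn(N, L, V)
    gt = torch.randint(0, V, (N, L))

    # cross_entropy wants (N, C, L) inputs for (N, L) sequence targets
    cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
    print(cce.shape)  # torch.Size([2, 10]): one loss term per timestep
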
doctr/models/recognition/utils.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
doctr/models/recognition/viptr/__init__.py CHANGED
@@ -1,4 +1 @@
-from doctr.file_utils import is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
+from .pytorch import *
doctr/models/recognition/viptr/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -16,7 +16,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS, decode_sequence
 
 from ...classification import vip_tiny
-from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["VIPTR", "viptr_tiny"]
@@ -70,7 +70,7 @@ class VIPTRPostProcessor(RecognitionPostProcessor):
 
     def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
         """Performs decoding of raw output with CTC and decoding of CTC predictions
-        with label_to_idx mapping dictionnary
+        with label_to_idx mapping dictionary
 
         Args:
             logits: raw output of the model, shape (N, C + 1, seq_len)
@@ -166,7 +166,7 @@ class VIPTR(RecognitionModel, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(decoded_features)
 
doctr/models/recognition/vitstr/__init__.py CHANGED
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
doctr/models/recognition/vitstr/base.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
doctr/models/recognition/vitstr/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS
 
 from ...classification import vit_b, vit_s
-from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -117,7 +117,7 @@ class ViTSTR(_ViTSTR, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(decoded_features)
 
@@ -149,7 +149,7 @@ class ViTSTR(_ViTSTR, nn.Module):
         # Input length : number of steps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1  # type: ignore[assignment]
+        seq_len = seq_len + 1
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>.
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
doctr/models/recognition/zoo.py CHANGED
@@ -1,12 +1,12 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from typing import Any
 
-from doctr.file_utils import is_tf_available, is_torch_available
 from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import _CompiledModule
 
 from .. import recognition
 from .predictor import RecognitionPredictor
@@ -23,11 +23,9 @@ ARCHS: list[str] = [
     "vitstr_small",
     "vitstr_base",
     "parseq",
+    "viptr_tiny",
 ]
 
-if is_torch_available():
-    ARCHS.extend(["viptr_tiny"])
-
 
 def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
     if isinstance(arch, str):
@@ -38,14 +36,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
             pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
         )
     else:
-        allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
-        if is_torch_available():
-            # Add VIPTR which is only available in torch at the moment
-            allowed_archs.append(recognition.VIPTR)
-            # Adding the type for torch compiled models to the allowed architectures
-            from doctr.models.utils import _CompiledModule
-
-            allowed_archs.append(_CompiledModule)
+        # Adding the type for torch compiled models to the allowed architectures
+        allowed_archs = [
+            recognition.CRNN,
+            recognition.SAR,
+            recognition.MASTER,
+            recognition.ViTSTR,
+            recognition.PARSeq,
+            recognition.VIPTR,
+            _CompiledModule,
+        ]
 
         if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
@@ -56,7 +56,7 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 128)
-    input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:]
+    input_shape = _model.cfg["input_shape"][-2:]
     predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)
 
     return predictor
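
With `viptr_tiny` promoted to a first-class entry in `ARCHS`, it can be requested through the zoo like any other recognition architecture. A hedged sketch (assumes the public `recognition_predictor` helper keeps its 0.x call signature):

    import numpy as np

    from doctr.models import recognition_predictor

    # Any name from ARCHS works here, including the now-unconditional "viptr_tiny".
    predictor = recognition_predictor("viptr_tiny", pretrained=True)

    # Per the predictor hunk above, crops are (H, W, C) numpy arrays and the
    # result is a list of (text, confidence) pairs.
    print(predictor([np.zeros((32, 128, 3), dtype=np.uint8)]))
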
doctr/models/utils/__init__.py CHANGED
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
doctr/models/utils/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -164,12 +164,13 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     """
     torch.onnx.export(
         model,
-        dummy_input,
+        dummy_input,  # type: ignore[arg-type]
         f"{model_name}.onnx",
         input_names=["input"],
         output_names=["logits"],
         dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
         export_params=True,
+        dynamo=False,
         verbose=False,
         **kwargs,
     )
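
The added `dynamo=False` pins the legacy TorchScript-based ONNX exporter, since newer PyTorch releases route `torch.onnx.export` through the dynamo exporter by default. A hedged usage sketch of the helper (the architecture and dummy-input shape are illustrative, not doctr defaults):

    import torch

    from doctr.models import recognition
    from doctr.models.utils import export_model_to_onnx

    model = recognition.vitstr_small(pretrained=True).eval()
    # One 32x128 RGB crop; the batch axis stays dynamic in the exported graph.
    dummy_input = torch.rand(1, 3, 32, 128)
    export_model_to_onnx(model, "vitstr_small", dummy_input)  # writes vitstr_small.onnx
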
doctr/models/zoo.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
doctr/transforms/functional/__init__.py CHANGED
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
doctr/transforms/functional/base.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -145,7 +145,8 @@ def create_shadow_mask(
 
     # Convert to absolute coords
     abs_contour: np.ndarray = (
-        np.stack(
+        np
+        .stack(
            (contour[:, 0] * target_shape[1], contour[:, 1] * target_shape[0]),
            axis=-1,
        )
doctr/transforms/functional/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -33,9 +33,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
     # Inverse the color
     if out.dtype == torch.uint8:
-        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)  # type: ignore[attr-defined]
+        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
     else:
-        out = out * rgb_shift.to(dtype=out.dtype)  # type: ignore[attr-defined]
+        out = out * rgb_shift.to(dtype=out.dtype)
     # Inverse the color
     out = 255 - out if out.dtype == torch.uint8 else 1 - out
     return out
@@ -77,7 +77,7 @@
     rotated_geoms: np.ndarray = rotate_abs_geoms(
         _geoms,
         angle,
-        img.shape[1:],
+        img.shape[1:],  # type: ignore[arg-type]
         expand,
     ).astype(np.float32)
 
@@ -124,7 +124,7 @@ def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwarg
     Returns:
         Shadowed image as a PyTorch tensor (same shape as input).
     """
-    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
+    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)  # type: ignore[arg-type]
     opacity = np.random.uniform(*opacity_range)
 
     # Apply Gaussian blur to the shadow mask
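
Note that both `rotate_sample` and `random_shadow` read `img.shape[1:]`: they expect channels-first (C, H, W) tensors now that the TensorFlow (H, W, C) path is gone. A short sketch (function path as in this diff; sizes illustrative):

    import torch

    from doctr.transforms.functional import random_shadow

    img = torch.rand(3, 64, 128)  # (C, H, W) float tensor in [0, 1]
    shadowed = random_shadow(img, opacity_range=(0.2, 0.8))
    print(shadowed.shape)  # torch.Size([3, 64, 128]), same shape as the input
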
doctr/transforms/modules/__init__.py CHANGED
@@ -1,8 +1,2 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
 from .base import *
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
doctr/transforms/modules/base.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2025, Mindee.
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -20,27 +20,13 @@ __all__ = ["SampleCompose", "ImageTransform", "ColorInversion", "OneOf", "Random
 class SampleCompose(NestedObject):
     """Implements a wrapper that will apply transformations sequentially on both image and target
 
-    .. tabs::
-
-        .. tab:: PyTorch
-
-            .. code:: python
-
-                >>> import numpy as np
-                >>> import torch
-                >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-                >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-                >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
-
-        .. tab:: TensorFlow
-
-            .. code:: python
-
-                >>> import numpy as np
-                >>> import tensorflow as tf
-                >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-                >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-                >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
+    .. code:: python
+
+        >>> import numpy as np
+        >>> import torch
+        >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
+        >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
+        >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
 
     Args:
         transforms: list of transformation modules
@@ -61,25 +47,12 @@ class SampleCompose(NestedObject):
 class ImageTransform(NestedObject):
     """Implements a transform wrapper to turn an image-only transformation into an image+target transform
 
-    .. tabs::
-
-        .. tab:: PyTorch
-
-            .. code:: python
-
-                >>> import torch
-                >>> from doctr.transforms import ImageTransform, ColorInversion
-                >>> transfo = ImageTransform(ColorInversion((32, 32)))
-                >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
-
-        .. tab:: TensorFlow
-
-            .. code:: python
-
-                >>> import tensorflow as tf
-                >>> from doctr.transforms import ImageTransform, ColorInversion
-                >>> transfo = ImageTransform(ColorInversion((32, 32)))
-                >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
+    .. code:: python
+
+        >>> import torch
+        >>> from doctr.transforms import ImageTransform, ColorInversion
+        >>> transfo = ImageTransform(ColorInversion((32, 32)))
+        >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
 
     Args:
         transform: the image transformation module to wrap
@@ -99,25 +72,12 @@ class ColorInversion(NestedObject):
     """Applies the following tranformation to a tensor (image or batch of images):
     convert to grayscale, colorize (shift 0-values randomly), and then invert colors
 
-    .. tabs::
-
-        .. tab:: PyTorch
-
-            .. code:: python
-
-                >>> import torch
-                >>> from doctr.transforms import ColorInversion
-                >>> transfo = ColorInversion(min_val=0.6)
-                >>> out = transfo(torch.rand(8, 64, 64, 3))
-
-        .. tab:: TensorFlow
-
-            .. code:: python
-
-                >>> import tensorflow as tf
-                >>> from doctr.transforms import ColorInversion
-                >>> transfo = ColorInversion(min_val=0.6)
-                >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
+    .. code:: python
+
+        >>> import torch
+        >>> from doctr.transforms import ColorInversion
+        >>> transfo = ColorInversion(min_val=0.6)
+        >>> out = transfo(torch.rand(8, 64, 64, 3))
 
     Args:
         min_val: range [min_val, 1] to colorize RGB pixels
@@ -136,25 +96,12 @@ class OneOf(NestedObject):
 class OneOf(NestedObject):
     """Randomly apply one of the input transformations
 
-    .. tabs::
-
-        .. tab:: PyTorch
-
-            .. code:: python
-
-                >>> import torch
-                >>> from doctr.transforms import OneOf
-                >>> transfo = OneOf([JpegQuality(), Gamma()])
-                >>> out = transfo(torch.rand(1, 64, 64, 3))
-
-        .. tab:: TensorFlow
-
-            .. code:: python
-
-                >>> import tensorflow as tf
-                >>> from doctr.transforms import OneOf
-                >>> transfo = OneOf([JpegQuality(), Gamma()])
-                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+    .. code:: python
+
+        >>> import torch
+        >>> from doctr.transforms import OneOf
+        >>> transfo = OneOf([JpegQuality(), Gamma()])
+        >>> out = transfo(torch.rand(1, 64, 64, 3))
 
     Args:
         transforms: list of transformations, one only will be picked
@@ -175,25 +122,12 @@ class RandomApply(NestedObject):
 class RandomApply(NestedObject):
     """Apply with a probability p the input transformation
 
-    .. tabs::
-
-        .. tab:: PyTorch
-
-            .. code:: python
-
-                >>> import torch
-                >>> from doctr.transforms import RandomApply
-                >>> transfo = RandomApply(Gamma(), p=.5)
-                >>> out = transfo(torch.rand(1, 64, 64, 3))
-
-        .. tab:: TensorFlow
-
-            .. code:: python
-
-                >>> import tensorflow as tf
-                >>> from doctr.transforms import RandomApply
-                >>> transfo = RandomApply(Gamma(), p=.5)
-                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+    .. code:: python
+
+        >>> import torch
+        >>> from doctr.transforms import RandomApply
+        >>> transfo = RandomApply(Gamma(), p=.5)
+        >>> out = transfo(torch.rand(1, 64, 64, 3))
 
     Args:
         transform: transformation to apply
@@ -258,7 +192,7 @@ class RandomCrop(NestedObject):
         scale = random.uniform(self.scale[0], self.scale[1])
         ratio = random.uniform(self.ratio[0], self.ratio[1])
 
-        height, width = img.shape[:2]
+        height, width = img.shape[-2:]
 
         # Calculate crop size
        crop_area = scale * width * height
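
The `RandomCrop` change is the same channels-first migration in miniature: `img.shape[:2]` reads (H, W) only on an (H, W, C) array, while `img.shape[-2:]` reads (H, W) regardless, including on a (C, H, W) tensor. A one-liner to illustrate:

    import torch

    img = torch.rand(3, 64, 128)   # (C, H, W)
    print(tuple(img.shape[:2]))    # (3, 64)   -- wrong under channels-first
    print(tuple(img.shape[-2:]))   # (64, 128) -- the intended (H, W)
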