python-doctr 0.12.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +0 -5
  3. doctr/datasets/datasets/__init__.py +1 -6
  4. doctr/datasets/datasets/pytorch.py +2 -2
  5. doctr/datasets/generator/__init__.py +1 -6
  6. doctr/datasets/vocabs.py +0 -2
  7. doctr/file_utils.py +2 -101
  8. doctr/io/image/__init__.py +1 -7
  9. doctr/io/image/pytorch.py +1 -1
  10. doctr/models/_utils.py +3 -3
  11. doctr/models/classification/magc_resnet/__init__.py +1 -6
  12. doctr/models/classification/magc_resnet/pytorch.py +2 -2
  13. doctr/models/classification/mobilenet/__init__.py +1 -6
  14. doctr/models/classification/predictor/__init__.py +1 -6
  15. doctr/models/classification/predictor/pytorch.py +1 -1
  16. doctr/models/classification/resnet/__init__.py +1 -6
  17. doctr/models/classification/textnet/__init__.py +1 -6
  18. doctr/models/classification/textnet/pytorch.py +1 -1
  19. doctr/models/classification/vgg/__init__.py +1 -6
  20. doctr/models/classification/vip/__init__.py +1 -4
  21. doctr/models/classification/vip/layers/__init__.py +1 -4
  22. doctr/models/classification/vip/layers/pytorch.py +1 -1
  23. doctr/models/classification/vit/__init__.py +1 -6
  24. doctr/models/classification/vit/pytorch.py +2 -2
  25. doctr/models/classification/zoo.py +6 -11
  26. doctr/models/detection/_utils/__init__.py +1 -6
  27. doctr/models/detection/core.py +1 -1
  28. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  29. doctr/models/detection/differentiable_binarization/base.py +4 -12
  30. doctr/models/detection/differentiable_binarization/pytorch.py +3 -3
  31. doctr/models/detection/fast/__init__.py +1 -6
  32. doctr/models/detection/fast/base.py +4 -14
  33. doctr/models/detection/fast/pytorch.py +4 -4
  34. doctr/models/detection/linknet/__init__.py +1 -6
  35. doctr/models/detection/linknet/base.py +3 -12
  36. doctr/models/detection/linknet/pytorch.py +2 -2
  37. doctr/models/detection/predictor/__init__.py +1 -6
  38. doctr/models/detection/predictor/pytorch.py +1 -1
  39. doctr/models/detection/zoo.py +15 -32
  40. doctr/models/factory/hub.py +8 -21
  41. doctr/models/kie_predictor/__init__.py +1 -6
  42. doctr/models/kie_predictor/pytorch.py +2 -6
  43. doctr/models/modules/layers/__init__.py +1 -6
  44. doctr/models/modules/layers/pytorch.py +3 -3
  45. doctr/models/modules/transformer/__init__.py +1 -6
  46. doctr/models/modules/transformer/pytorch.py +2 -2
  47. doctr/models/modules/vision_transformer/__init__.py +1 -6
  48. doctr/models/predictor/__init__.py +1 -6
  49. doctr/models/predictor/base.py +3 -8
  50. doctr/models/predictor/pytorch.py +2 -5
  51. doctr/models/preprocessor/__init__.py +1 -6
  52. doctr/models/preprocessor/pytorch.py +27 -32
  53. doctr/models/recognition/crnn/__init__.py +1 -6
  54. doctr/models/recognition/crnn/pytorch.py +6 -6
  55. doctr/models/recognition/master/__init__.py +1 -6
  56. doctr/models/recognition/master/pytorch.py +5 -5
  57. doctr/models/recognition/parseq/__init__.py +1 -6
  58. doctr/models/recognition/parseq/pytorch.py +5 -5
  59. doctr/models/recognition/predictor/__init__.py +1 -6
  60. doctr/models/recognition/predictor/_utils.py +7 -16
  61. doctr/models/recognition/predictor/pytorch.py +1 -2
  62. doctr/models/recognition/sar/__init__.py +1 -6
  63. doctr/models/recognition/sar/pytorch.py +3 -3
  64. doctr/models/recognition/viptr/__init__.py +1 -4
  65. doctr/models/recognition/viptr/pytorch.py +3 -3
  66. doctr/models/recognition/vitstr/__init__.py +1 -6
  67. doctr/models/recognition/vitstr/pytorch.py +3 -3
  68. doctr/models/recognition/zoo.py +13 -13
  69. doctr/models/utils/__init__.py +1 -6
  70. doctr/models/utils/pytorch.py +1 -1
  71. doctr/transforms/functional/__init__.py +1 -6
  72. doctr/transforms/functional/pytorch.py +4 -4
  73. doctr/transforms/modules/__init__.py +1 -7
  74. doctr/transforms/modules/base.py +26 -92
  75. doctr/transforms/modules/pytorch.py +28 -26
  76. doctr/utils/geometry.py +6 -10
  77. doctr/utils/visualization.py +1 -1
  78. doctr/version.py +1 -1
  79. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +18 -75
  80. python_doctr-1.0.0.dist-info/RECORD +149 -0
  81. doctr/datasets/datasets/tensorflow.py +0 -59
  82. doctr/datasets/generator/tensorflow.py +0 -58
  83. doctr/datasets/loader.py +0 -94
  84. doctr/io/image/tensorflow.py +0 -101
  85. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  86. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  87. doctr/models/classification/predictor/tensorflow.py +0 -60
  88. doctr/models/classification/resnet/tensorflow.py +0 -418
  89. doctr/models/classification/textnet/tensorflow.py +0 -275
  90. doctr/models/classification/vgg/tensorflow.py +0 -125
  91. doctr/models/classification/vit/tensorflow.py +0 -201
  92. doctr/models/detection/_utils/tensorflow.py +0 -34
  93. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  94. doctr/models/detection/fast/tensorflow.py +0 -427
  95. doctr/models/detection/linknet/tensorflow.py +0 -377
  96. doctr/models/detection/predictor/tensorflow.py +0 -70
  97. doctr/models/kie_predictor/tensorflow.py +0 -187
  98. doctr/models/modules/layers/tensorflow.py +0 -171
  99. doctr/models/modules/transformer/tensorflow.py +0 -235
  100. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  101. doctr/models/predictor/tensorflow.py +0 -155
  102. doctr/models/preprocessor/tensorflow.py +0 -122
  103. doctr/models/recognition/crnn/tensorflow.py +0 -317
  104. doctr/models/recognition/master/tensorflow.py +0 -320
  105. doctr/models/recognition/parseq/tensorflow.py +0 -516
  106. doctr/models/recognition/predictor/tensorflow.py +0 -79
  107. doctr/models/recognition/sar/tensorflow.py +0 -423
  108. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  109. doctr/models/utils/tensorflow.py +0 -189
  110. doctr/transforms/functional/tensorflow.py +0 -254
  111. doctr/transforms/modules/tensorflow.py +0 -562
  112. python_doctr-0.12.0.dist-info/RECORD +0 -180
  113. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +0 -0
  114. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/licenses/LICENSE +0 -0
  115. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  116. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -19,7 +19,7 @@ from doctr.datasets import VOCABS
  from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

  from ...classification import vit_s
- from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+ from ...utils import _bf16_to_float32, load_pretrained_params
  from .base import _PARSeq, _PARSeqPostProcessor

  __all__ = ["PARSeq", "parseq"]
@@ -299,7 +299,7 @@ class PARSeq(_PARSeq, nn.Module):

  # Stop decoding if all sequences have reached the EOS token
  # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
- if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all(): # type: ignore[attr-defined]
+ if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
  break

  logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1)
@@ -314,7 +314,7 @@ class PARSeq(_PARSeq, nn.Module):

  # Create padding mask for refined target input maskes all behind EOS token as False
  # (N, 1, 1, max_length)
- target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
+ target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
  mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
  logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))

@@ -367,7 +367,7 @@ class PARSeq(_PARSeq, nn.Module):
  # remove the [EOS] tokens for the succeeding perms
  if i == 1:
  gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
- n = (gt_out != self.vocab_size + 2).sum().item() # type: ignore[attr-defined]
+ n = (gt_out != self.vocab_size + 2).sum().item()

  loss /= loss_numel

@@ -391,7 +391,7 @@ class PARSeq(_PARSeq, nn.Module):

  if target is None or return_preds:
  # Disable for torch.compile compatibility
- @torch.compiler.disable # type: ignore[attr-defined]
+ @torch.compiler.disable
  def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
  return self.postprocessor(logits)

@@ -1,6 +1 @@
- from doctr.file_utils import is_tf_available, is_torch_available
-
- if is_torch_available():
- from .pytorch import *
- elif is_tf_available():
- from .tensorflow import * # type: ignore[assignment]
+ from .pytorch import *
@@ -18,17 +18,15 @@ def split_crops(
  max_ratio: float,
  target_ratio: int,
  split_overlap_ratio: float,
- channels_last: bool = True,
  ) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
  """
  Split crops horizontally if they exceed a given aspect ratio.

  Args:
- crops: List of image crops (H, W, C) if channels_last else (C, H, W).
+ crops: List of image crops (H, W, C).
  max_ratio: Aspect ratio threshold above which crops are split.
  target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
  split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
- channels_last: Whether the crops are in channels-last format.

  Returns:
  A tuple containing:
@@ -44,14 +42,14 @@ def split_crops(
  crop_map: list[int | tuple[int, int, float]] = []

  for crop in crops:
- h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
+ h, w = crop.shape[:2]
  aspect_ratio = w / h

  if aspect_ratio > max_ratio:
  split_width = max(1, math.ceil(h * target_ratio))
  overlap_width = max(0, math.floor(split_width * split_overlap_ratio))

- splits, last_overlap = _split_horizontally(crop, split_width, overlap_width, channels_last)
+ splits, last_overlap = _split_horizontally(crop, split_width, overlap_width)

  # Remove any empty splits
  splits = [s for s in splits if all(dim > 0 for dim in s.shape)]
@@ -70,23 +68,20 @@ def _split_horizontally(
  return new_crops, crop_map, remap_required


- def _split_horizontally(
- image: np.ndarray, split_width: int, overlap_width: int, channels_last: bool
- ) -> tuple[list[np.ndarray], float]:
+ def _split_horizontally(image: np.ndarray, split_width: int, overlap_width: int) -> tuple[list[np.ndarray], float]:
  """
  Horizontally split a single image with overlapping regions.

  Args:
- image: The image to split (H, W, C) if channels_last else (C, H, W).
+ image: The image to split (H, W, C).
  split_width: Width of each split.
  overlap_width: Width of the overlapping region.
- channels_last: Whether the image is in channels-last format.

  Returns:
  - A list of horizontal image slices.
  - The actual overlap ratio of the last split.
  """
- image_width = image.shape[1] if channels_last else image.shape[-1]
+ image_width = image.shape[1]
  if image_width <= split_width:
  return [image], 0.0

@@ -101,11 +96,7 @@ def _split_horizontally(
  splits = []
  for start_col in starts:
  end_col = start_col + split_width
- if channels_last:
- split = image[:, start_col:end_col, :]
- else:
- split = image[:, :, start_col:end_col]
- splits.append(split)
+ splits.append(image[:, start_col:end_col, :])

  # Calculate the last overlap ratio, if only one split no overlap
  last_overlap = 0
@@ -44,7 +44,7 @@ class RecognitionPredictor(nn.Module):
  @torch.inference_mode()
  def forward(
  self,
- crops: Sequence[np.ndarray | torch.Tensor],
+ crops: Sequence[np.ndarray],
  **kwargs: Any,
  ) -> list[tuple[str, float]]:
  if len(crops) == 0:
@@ -61,7 +61,6 @@ class RecognitionPredictor(nn.Module):
  self.critical_ar,
  self.target_ar,
  self.overlap_ratio,
- isinstance(crops[0], np.ndarray),
  )
  if remapped:
  crops = new_crops
@@ -1,6 +1 @@
- from doctr.file_utils import is_tf_available, is_torch_available
-
- if is_torch_available():
- from .pytorch import *
- elif is_tf_available():
- from .tensorflow import * # type: ignore[assignment]
+ from .pytorch import *
@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
  from doctr.datasets import VOCABS

  from ...classification import resnet31
- from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+ from ...utils import _bf16_to_float32, load_pretrained_params
  from ..core import RecognitionModel, RecognitionPostProcessor

  __all__ = ["SAR", "sar_resnet31"]
@@ -272,7 +272,7 @@ class SAR(nn.Module, RecognitionModel):

  if target is None or return_preds:
  # Disable for torch.compile compatibility
- @torch.compiler.disable # type: ignore[attr-defined]
+ @torch.compiler.disable
  def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
  return self.postprocessor(decoded_features)

@@ -304,7 +304,7 @@ class SAR(nn.Module, RecognitionModel):
  # Input length : number of timesteps
  input_len = model_output.shape[1]
  # Add one for additional <eos> token
- seq_len = seq_len + 1 # type: ignore[assignment]
+ seq_len = seq_len + 1
  # Compute loss
  # (N, L, vocab_size + 1)
  cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
@@ -1,4 +1 @@
- from doctr.file_utils import is_torch_available
-
- if is_torch_available():
- from .pytorch import *
+ from .pytorch import *
@@ -16,7 +16,7 @@ from torchvision.models._utils import IntermediateLayerGetter
  from doctr.datasets import VOCABS, decode_sequence

  from ...classification import vip_tiny
- from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+ from ...utils import _bf16_to_float32, load_pretrained_params
  from ..core import RecognitionModel, RecognitionPostProcessor

  __all__ = ["VIPTR", "viptr_tiny"]
@@ -70,7 +70,7 @@ class VIPTRPostProcessor(RecognitionPostProcessor):

  def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
  """Performs decoding of raw output with CTC and decoding of CTC predictions
- with label_to_idx mapping dictionnary
+ with label_to_idx mapping dictionary

  Args:
  logits: raw output of the model, shape (N, C + 1, seq_len)
@@ -166,7 +166,7 @@ class VIPTR(RecognitionModel, nn.Module):

  if target is None or return_preds:
  # Disable for torch.compile compatibility
- @torch.compiler.disable # type: ignore[attr-defined]
+ @torch.compiler.disable
  def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
  return self.postprocessor(decoded_features)

@@ -1,6 +1 @@
- from doctr.file_utils import is_tf_available, is_torch_available
-
- if is_torch_available():
- from .pytorch import *
- elif is_tf_available():
- from .tensorflow import * # type: ignore[assignment]
+ from .pytorch import *
@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
  from doctr.datasets import VOCABS

  from ...classification import vit_b, vit_s
- from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+ from ...utils import _bf16_to_float32, load_pretrained_params
  from .base import _ViTSTR, _ViTSTRPostProcessor

  __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -117,7 +117,7 @@ class ViTSTR(_ViTSTR, nn.Module):

  if target is None or return_preds:
  # Disable for torch.compile compatibility
- @torch.compiler.disable # type: ignore[attr-defined]
+ @torch.compiler.disable
  def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
  return self.postprocessor(decoded_features)

@@ -149,7 +149,7 @@ class ViTSTR(_ViTSTR, nn.Module):
  # Input length : number of steps
  input_len = model_output.shape[1]
  # Add one for additional <eos> token (sos disappear in shift!)
- seq_len = seq_len + 1 # type: ignore[assignment]
+ seq_len = seq_len + 1
  # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
  # The "masked" first gt char is <sos>.
  cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -5,8 +5,8 @@

  from typing import Any

- from doctr.file_utils import is_tf_available, is_torch_available
  from doctr.models.preprocessor import PreProcessor
+ from doctr.models.utils import _CompiledModule

  from .. import recognition
  from .predictor import RecognitionPredictor
@@ -23,11 +23,9 @@ ARCHS: list[str] = [
  "vitstr_small",
  "vitstr_base",
  "parseq",
+ "viptr_tiny",
  ]

- if is_torch_available():
- ARCHS.extend(["viptr_tiny"])
-

  def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
  if isinstance(arch, str):
@@ -38,14 +36,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
  pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
  )
  else:
- allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
- if is_torch_available():
- # Add VIPTR which is only available in torch at the moment
- allowed_archs.append(recognition.VIPTR)
- # Adding the type for torch compiled models to the allowed architectures
- from doctr.models.utils import _CompiledModule
-
- allowed_archs.append(_CompiledModule)
+ # Adding the type for torch compiled models to the allowed architectures
+ allowed_archs = [
+ recognition.CRNN,
+ recognition.SAR,
+ recognition.MASTER,
+ recognition.ViTSTR,
+ recognition.PARSeq,
+ recognition.VIPTR,
+ _CompiledModule,
+ ]

  if not isinstance(arch, tuple(allowed_archs)):
  raise ValueError(f"unknown architecture: {type(arch)}")
@@ -56,7 +56,7 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
  kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
  kwargs["std"] = kwargs.get("std", _model.cfg["std"])
  kwargs["batch_size"] = kwargs.get("batch_size", 128)
- input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:]
+ input_shape = _model.cfg["input_shape"][-2:]
  predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)

  return predictor
@@ -1,6 +1 @@
- from doctr.file_utils import is_tf_available, is_torch_available
-
- if is_torch_available():
- from .pytorch import *
- elif is_tf_available():
- from .tensorflow import * # type: ignore[assignment]
+ from .pytorch import *
@@ -164,7 +164,7 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
  """
  torch.onnx.export(
  model,
- dummy_input,
+ dummy_input, # type: ignore[arg-type]
  f"{model_name}.onnx",
  input_names=["input"],
  output_names=["logits"],
@@ -1,6 +1 @@
- from doctr.file_utils import is_tf_available, is_torch_available
-
- if is_torch_available():
- from .pytorch import *
- elif is_tf_available():
- from .tensorflow import *
+ from .pytorch import *
@@ -33,9 +33,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
  rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
  # Inverse the color
  if out.dtype == torch.uint8:
- out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8) # type: ignore[attr-defined]
+ out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
  else:
- out = out * rgb_shift.to(dtype=out.dtype) # type: ignore[attr-defined]
+ out = out * rgb_shift.to(dtype=out.dtype)
  # Inverse the color
  out = 255 - out if out.dtype == torch.uint8 else 1 - out
  return out
@@ -77,7 +77,7 @@ def rotate_sample(
  rotated_geoms: np.ndarray = rotate_abs_geoms(
  _geoms,
  angle,
- img.shape[1:],
+ img.shape[1:], # type: ignore[arg-type]
  expand,
  ).astype(np.float32)

@@ -124,7 +124,7 @@ def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwarg
  Returns:
  Shadowed image as a PyTorch tensor (same shape as input).
  """
- shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
+ shadow_mask = create_shadow_mask(img.shape[1:], **kwargs) # type: ignore[arg-type]
  opacity = np.random.uniform(*opacity_range)

  # Apply Gaussian blur to the shadow mask
@@ -1,8 +1,2 @@
- from doctr.file_utils import is_tf_available, is_torch_available
-
  from .base import *
-
- if is_torch_available():
- from .pytorch import *
- elif is_tf_available():
- from .tensorflow import * # type: ignore[assignment]
+ from .pytorch import *
@@ -20,27 +20,13 @@ __all__ = ["SampleCompose", "ImageTransform", "ColorInversion", "OneOf", "Random
20
20
  class SampleCompose(NestedObject):
21
21
  """Implements a wrapper that will apply transformations sequentially on both image and target
22
22
 
23
- .. tabs::
23
+ .. code:: python
24
24
 
25
- .. tab:: PyTorch
26
-
27
- .. code:: python
28
-
29
- >>> import numpy as np
30
- >>> import torch
31
- >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
32
- >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
33
- >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
34
-
35
- .. tab:: TensorFlow
36
-
37
- .. code:: python
38
-
39
- >>> import numpy as np
40
- >>> import tensorflow as tf
41
- >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
42
- >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
43
- >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
25
+ >>> import numpy as np
26
+ >>> import torch
27
+ >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
28
+ >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
29
+ >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
44
30
 
45
31
  Args:
46
32
  transforms: list of transformation modules
@@ -61,25 +47,12 @@ class SampleCompose(NestedObject):
  class ImageTransform(NestedObject):
  """Implements a transform wrapper to turn an image-only transformation into an image+target transform

- .. tabs::
-
- .. tab:: PyTorch
-
- .. code:: python
+ .. code:: python

- >>> import torch
- >>> from doctr.transforms import ImageTransform, ColorInversion
- >>> transfo = ImageTransform(ColorInversion((32, 32)))
- >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
-
- .. tab:: TensorFlow
-
- .. code:: python
-
- >>> import tensorflow as tf
- >>> from doctr.transforms import ImageTransform, ColorInversion
- >>> transfo = ImageTransform(ColorInversion((32, 32)))
- >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
+ >>> import torch
+ >>> from doctr.transforms import ImageTransform, ColorInversion
+ >>> transfo = ImageTransform(ColorInversion((32, 32)))
+ >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)

  Args:
  transform: the image transformation module to wrap
@@ -99,25 +72,12 @@ class ColorInversion(NestedObject):
  """Applies the following tranformation to a tensor (image or batch of images):
  convert to grayscale, colorize (shift 0-values randomly), and then invert colors

- .. tabs::
-
- .. tab:: PyTorch
-
- .. code:: python
-
- >>> import torch
- >>> from doctr.transforms import ColorInversion
- >>> transfo = ColorInversion(min_val=0.6)
- >>> out = transfo(torch.rand(8, 64, 64, 3))
+ .. code:: python

- .. tab:: TensorFlow
-
- .. code:: python
-
- >>> import tensorflow as tf
- >>> from doctr.transforms import ColorInversion
- >>> transfo = ColorInversion(min_val=0.6)
- >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
+ >>> import torch
+ >>> from doctr.transforms import ColorInversion
+ >>> transfo = ColorInversion(min_val=0.6)
+ >>> out = transfo(torch.rand(8, 64, 64, 3))

  Args:
  min_val: range [min_val, 1] to colorize RGB pixels
@@ -136,25 +96,12 @@ class ColorInversion(NestedObject):
  class OneOf(NestedObject):
  """Randomly apply one of the input transformations

- .. tabs::
-
- .. tab:: PyTorch
-
- .. code:: python
-
- >>> import torch
- >>> from doctr.transforms import OneOf
- >>> transfo = OneOf([JpegQuality(), Gamma()])
- >>> out = transfo(torch.rand(1, 64, 64, 3))
-
- .. tab:: TensorFlow
+ .. code:: python

- .. code:: python
-
- >>> import tensorflow as tf
- >>> from doctr.transforms import OneOf
- >>> transfo = OneOf([JpegQuality(), Gamma()])
- >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+ >>> import torch
+ >>> from doctr.transforms import OneOf
+ >>> transfo = OneOf([JpegQuality(), Gamma()])
+ >>> out = transfo(torch.rand(1, 64, 64, 3))

  Args:
  transforms: list of transformations, one only will be picked
@@ -175,25 +122,12 @@ class OneOf(NestedObject):
  class RandomApply(NestedObject):
  """Apply with a probability p the input transformation

- .. tabs::
-
- .. tab:: PyTorch
-
- .. code:: python
-
- >>> import torch
- >>> from doctr.transforms import RandomApply
- >>> transfo = RandomApply(Gamma(), p=.5)
- >>> out = transfo(torch.rand(1, 64, 64, 3))
-
- .. tab:: TensorFlow
-
- .. code:: python
+ .. code:: python

- >>> import tensorflow as tf
- >>> from doctr.transforms import RandomApply
- >>> transfo = RandomApply(Gamma(), p=.5)
- >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+ >>> import torch
+ >>> from doctr.transforms import RandomApply
+ >>> transfo = RandomApply(Gamma(), p=.5)
+ >>> out = transfo(torch.rand(1, 64, 64, 3))

  Args:
  transform: transformation to apply
@@ -13,7 +13,7 @@ from torch.nn.functional import pad
  from torchvision.transforms import functional as F
  from torchvision.transforms import transforms as T

- from ..functional.pytorch import random_shadow
+ from ..functional import random_shadow

  __all__ = [
  "Resize",
@@ -27,7 +27,21 @@ __all__ = [


  class Resize(T.Resize):
- """Resize the input image to the given size"""
+ """Resize the input image to the given size
+
+ >>> import torch
+ >>> from doctr.transforms import Resize
+ >>> transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=True)
+ >>> out = transfo(torch.rand((3, 64, 64)))
+
+ Args:
+ size: output size in pixels, either a tuple (height, width) or a single integer for square images
+ interpolation: interpolation mode to use for resizing, default is bilinear
+ preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
+ if True, the image will be resized to fit within the target size while maintaining its aspect ratio
+ symmetric_pad: whether to symmetrically pad the image to the target size,
+ if True, the image will be padded equally on both sides to fit the target size
+ """

  def __init__(
  self,
@@ -36,25 +50,19 @@ class Resize(T.Resize):
  preserve_aspect_ratio: bool = False,
  symmetric_pad: bool = False,
  ) -> None:
- super().__init__(size, interpolation, antialias=True)
+ super().__init__(size if isinstance(size, (list, tuple)) else (size, size), interpolation, antialias=True)
  self.preserve_aspect_ratio = preserve_aspect_ratio
  self.symmetric_pad = symmetric_pad

- if not isinstance(self.size, (int, tuple, list)):
- raise AssertionError("size should be either a tuple, a list or an int")
-
  def forward(
  self,
  img: torch.Tensor,
  target: np.ndarray | None = None,
  ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
- if isinstance(self.size, int):
- target_ratio = img.shape[-2] / img.shape[-1]
- else:
- target_ratio = self.size[0] / self.size[1]
+ target_ratio = self.size[0] / self.size[1]
  actual_ratio = img.shape[-2] / img.shape[-1]

- if not self.preserve_aspect_ratio or (target_ratio == actual_ratio and (isinstance(self.size, (tuple, list)))):
+ if not self.preserve_aspect_ratio or (target_ratio == actual_ratio):
  # If we don't preserve the aspect ratio or the wanted aspect ratio is the same than the original one
  # We can use with the regular resize
  if target is not None:
@@ -62,16 +70,10 @@ class Resize(T.Resize):
  return super().forward(img)
  else:
  # Resize
- if isinstance(self.size, (tuple, list)):
- if actual_ratio > target_ratio:
- tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1))
- else:
- tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1])
- elif isinstance(self.size, int): # self.size is the longest side, infer the other
- if img.shape[-2] <= img.shape[-1]:
- tmp_size = (max(int(self.size * actual_ratio), 1), self.size)
- else:
- tmp_size = (self.size, max(int(self.size / actual_ratio), 1))
+ if actual_ratio > target_ratio:
+ tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1))
+ else:
+ tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1])

  # Scale image
  img = F.resize(img, tmp_size, self.interpolation, antialias=True)
@@ -93,14 +95,14 @@ class Resize(T.Resize):
  if self.preserve_aspect_ratio:
  # Get absolute coords
  if target.shape[1:] == (4,):
- if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
+ if self.symmetric_pad:
  target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
  target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
  else:
  target[:, [0, 2]] *= raw_shape[-1] / img.shape[-1]
  target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
  elif target.shape[1:] == (4, 2):
- if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
+ if self.symmetric_pad:
  target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
  target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
  else:
@@ -143,9 +145,9 @@ class GaussianNoise(torch.nn.Module):
  # Reshape the distribution
  noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std
  if x.dtype == torch.uint8:
- return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8) # type: ignore[attr-defined]
+ return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
  else:
- return (x + noise.to(dtype=x.dtype)).clamp(0, 1) # type: ignore[attr-defined]
+ return (x + noise.to(dtype=x.dtype)).clamp(0, 1)

  def extra_repr(self) -> str:
  return f"mean={self.mean}, std={self.std}"
@@ -233,7 +235,7 @@ class RandomShadow(torch.nn.Module):
  try:
  if x.dtype == torch.uint8:
  return (
- ( # type: ignore[attr-defined]
+ (
  255
  * random_shadow(
  x.to(dtype=torch.float32) / 255,