python-doctr 0.11.0 (py3-none-any.whl) → 1.0.0 (py3-none-any.whl)

This diff covers the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
doctr/models/utils/pytorch.py CHANGED
@@ -7,6 +7,7 @@ import logging
 from typing import Any
 
 import torch
+import validators
 from torch import nn
 
 from doctr.utils.data import download_from_url
@@ -36,7 +37,7 @@ def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
 
 def load_pretrained_params(
     model: nn.Module,
-    url: str | None = None,
+    path_or_url: str | None = None,
     hash_prefix: str | None = None,
     ignore_keys: list[str] | None = None,
     **kwargs: Any,
@@ -44,33 +45,42 @@ def load_pretrained_params(
     """Load a set of parameters onto a model
 
     >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
+    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")
 
     Args:
         model: the PyTorch model to be loaded
-        url: URL of the zipped set of parameters
+        path_or_url: the path or URL to the model parameters (checkpoint)
         hash_prefix: first characters of SHA256 expected hash
         ignore_keys: list of weights to be ignored from the state_dict
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
-    if url is None:
-        logging.warning("Invalid model URL, using default initialization.")
-    else:
-        archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
+    if path_or_url is None:
+        logging.warning("No model URL or Path provided, using default initialization.")
+        return
+
+    archive_path = (
+        download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
+        if validators.url(path_or_url)
+        else path_or_url
+    )
 
-        # Read state_dict
-        state_dict = torch.load(archive_path, map_location="cpu")
+    # Read state_dict
+    state_dict = torch.load(archive_path, map_location="cpu")
 
-        # Remove weights from the state_dict
-        if ignore_keys is not None and len(ignore_keys) > 0:
-            for key in ignore_keys:
+    # Remove weights from the state_dict
+    if ignore_keys is not None and len(ignore_keys) > 0:
+        for key in ignore_keys:
+            if key in state_dict:
                 state_dict.pop(key)
-            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-            if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0:
-                raise ValueError("unable to load state_dict, due to non-matching keys.")
-        else:
-            # Load weights
-            model.load_state_dict(state_dict)
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+        if any(k not in ignore_keys for k in missing_keys + unexpected_keys):
+            raise ValueError(
+                "Unable to load state_dict, due to non-matching keys.\n"
+                + f"Unexpected keys: {unexpected_keys}\nMissing keys: {missing_keys}"
+            )
+    else:
+        # Load weights
+        model.load_state_dict(state_dict)
 
 
 def conv_sequence_pt(
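
A point worth flagging for users: `load_pretrained_params` now accepts a local file path as well as a URL, since `validators.url()` gates the download step. A minimal usage sketch based on the hunk above (the checkpoint URL is the docstring's placeholder; the model and `ignore_keys` names are illustrative):

    import torch
    from doctr.models import crnn_vgg16_bn, load_pretrained_params

    model = crnn_vgg16_bn(pretrained=False)

    # Remote checkpoint: downloaded and cached, exactly as before
    load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")

    # Local checkpoint: new in 1.0.0 -- validators.url() is falsy for plain paths,
    # so the file goes straight to torch.load(..., map_location="cpu")
    load_pretrained_params(
        model,
        "/path/to/checkpoint.pt",
        ignore_keys=["linear.weight", "linear.bias"],  # e.g. a head whose vocab size changed
    )
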
@@ -154,7 +164,7 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     """
     torch.onnx.export(
         model,
-        dummy_input,
+        dummy_input,  # type: ignore[arg-type]
         f"{model_name}.onnx",
         input_names=["input"],
        output_names=["logits"],
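
The only change here silences mypy on `dummy_input`, but since the hunk header spells out the signature of `export_model_to_onnx`, a quick sketch of how it is driven may be useful (the tiny network and the `toy_model` name are stand-ins, not doctr models):

    import torch
    from torch import nn

    from doctr.models.utils import export_model_to_onnx

    # Tiny stand-in network; any nn.Module taking a single tensor input works
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Flatten(), nn.Linear(8 * 32 * 32, 10))
    dummy_input = torch.rand((1, 3, 32, 32), dtype=torch.float32)

    # Writes "toy_model.onnx" in the working directory; per the source above,
    # the exported graph names its input "input" and its output "logits"
    model_path = export_model_to_onnx(model, model_name="toy_model", dummy_input=dummy_input)
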
doctr/transforms/functional/__init__.py CHANGED
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
doctr/transforms/functional/pytorch.py CHANGED
@@ -33,9 +33,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
     # Inverse the color
     if out.dtype == torch.uint8:
-        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)  # type: ignore[attr-defined]
+        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
     else:
-        out = out * rgb_shift.to(dtype=out.dtype)  # type: ignore[attr-defined]
+        out = out * rgb_shift.to(dtype=out.dtype)
     # Inverse the color
     out = 255 - out if out.dtype == torch.uint8 else 1 - out
     return out
@@ -77,7 +77,7 @@ def rotate_sample(
     rotated_geoms: np.ndarray = rotate_abs_geoms(
         _geoms,
         angle,
-        img.shape[1:],
+        img.shape[1:],  # type: ignore[arg-type]
         expand,
     ).astype(np.float32)
 
@@ -124,7 +124,7 @@ def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwarg
     Returns:
         Shadowed image as a PyTorch tensor (same shape as input).
     """
-    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
+    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)  # type: ignore[arg-type]
     opacity = np.random.uniform(*opacity_range)
 
     # Apply Gaussian blur to the shadow mask
doctr/transforms/modules/__init__.py CHANGED
@@ -1,8 +1,2 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
 from .base import *
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
doctr/transforms/modules/base.py CHANGED
@@ -20,27 +20,13 @@ __all__ = ["SampleCompose", "ImageTransform", "ColorInversion", "OneOf", "Random
 class SampleCompose(NestedObject):
     """Implements a wrapper that will apply transformations sequentially on both image and target
 
-    .. tabs::
+    .. code:: python
 
-        .. tab:: PyTorch
-
-            .. code:: python
-
-                >>> import numpy as np
-                >>> import torch
-                >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-                >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-                >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
-
-        .. tab:: TensorFlow
-
-            .. code:: python
-
-                >>> import numpy as np
-                >>> import tensorflow as tf
-                >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-                >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-                >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
+        >>> import numpy as np
+        >>> import torch
+        >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
+        >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
+        >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
 
     Args:
         transforms: list of transformation modules
@@ -61,25 +47,12 @@ class SampleCompose(NestedObject):
 class ImageTransform(NestedObject):
     """Implements a transform wrapper to turn an image-only transformation into an image+target transform
 
-    .. tabs::
-
-        .. tab:: PyTorch
-
-            .. code:: python
+    .. code:: python
 
-                >>> import torch
-                >>> from doctr.transforms import ImageTransform, ColorInversion
-                >>> transfo = ImageTransform(ColorInversion((32, 32)))
-                >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
-
-        .. tab:: TensorFlow
-
-            .. code:: python
-
-                >>> import tensorflow as tf
-                >>> from doctr.transforms import ImageTransform, ColorInversion
-                >>> transfo = ImageTransform(ColorInversion((32, 32)))
-                >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
+        >>> import torch
+        >>> from doctr.transforms import ImageTransform, ColorInversion
+        >>> transfo = ImageTransform(ColorInversion((32, 32)))
+        >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
 
     Args:
         transform: the image transformation module to wrap
@@ -99,25 +72,12 @@ class ColorInversion(NestedObject):
     """Applies the following tranformation to a tensor (image or batch of images):
     convert to grayscale, colorize (shift 0-values randomly), and then invert colors
 
-    .. tabs::
-
-        .. tab:: PyTorch
-
-            .. code:: python
-
-                >>> import torch
-                >>> from doctr.transforms import ColorInversion
-                >>> transfo = ColorInversion(min_val=0.6)
-                >>> out = transfo(torch.rand(8, 64, 64, 3))
+    .. code:: python
 
-        .. tab:: TensorFlow
-
-            .. code:: python
-
-                >>> import tensorflow as tf
-                >>> from doctr.transforms import ColorInversion
-                >>> transfo = ColorInversion(min_val=0.6)
-                >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
+        >>> import torch
+        >>> from doctr.transforms import ColorInversion
+        >>> transfo = ColorInversion(min_val=0.6)
+        >>> out = transfo(torch.rand(8, 64, 64, 3))
 
     Args:
         min_val: range [min_val, 1] to colorize RGB pixels
@@ -136,25 +96,12 @@ class ColorInversion(NestedObject):
 class OneOf(NestedObject):
     """Randomly apply one of the input transformations
 
-    .. tabs::
-
-        .. tab:: PyTorch
-
-            .. code:: python
-
-                >>> import torch
-                >>> from doctr.transforms import OneOf
-                >>> transfo = OneOf([JpegQuality(), Gamma()])
-                >>> out = transfo(torch.rand(1, 64, 64, 3))
-
-        .. tab:: TensorFlow
+    .. code:: python
 
-            .. code:: python
-
-                >>> import tensorflow as tf
-                >>> from doctr.transforms import OneOf
-                >>> transfo = OneOf([JpegQuality(), Gamma()])
-                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+        >>> import torch
+        >>> from doctr.transforms import OneOf
+        >>> transfo = OneOf([JpegQuality(), Gamma()])
+        >>> out = transfo(torch.rand(1, 64, 64, 3))
 
     Args:
         transforms: list of transformations, one only will be picked
@@ -175,25 +122,12 @@ class OneOf(NestedObject):
 class RandomApply(NestedObject):
     """Apply with a probability p the input transformation
 
-    .. tabs::
-
-        .. tab:: PyTorch
-
-            .. code:: python
-
-                >>> import torch
-                >>> from doctr.transforms import RandomApply
-                >>> transfo = RandomApply(Gamma(), p=.5)
-                >>> out = transfo(torch.rand(1, 64, 64, 3))
-
-        .. tab:: TensorFlow
-
-            .. code:: python
+    .. code:: python
 
-                >>> import tensorflow as tf
-                >>> from doctr.transforms import RandomApply
-                >>> transfo = RandomApply(Gamma(), p=.5)
-                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+        >>> import torch
+        >>> from doctr.transforms import RandomApply
+        >>> transfo = RandomApply(Gamma(), p=.5)
+        >>> out = transfo(torch.rand(1, 64, 64, 3))
 
     Args:
         transform: transformation to apply
doctr/transforms/modules/pytorch.py CHANGED
@@ -13,7 +13,7 @@ from torch.nn.functional import pad
 from torchvision.transforms import functional as F
 from torchvision.transforms import transforms as T
 
-from ..functional.pytorch import random_shadow
+from ..functional import random_shadow
 
 __all__ = [
     "Resize",
@@ -27,7 +27,21 @@ __all__ = [
 
 
 class Resize(T.Resize):
-    """Resize the input image to the given size"""
+    """Resize the input image to the given size
+
+    >>> import torch
+    >>> from doctr.transforms import Resize
+    >>> transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=True)
+    >>> out = transfo(torch.rand((3, 64, 64)))
+
+    Args:
+        size: output size in pixels, either a tuple (height, width) or a single integer for square images
+        interpolation: interpolation mode to use for resizing, default is bilinear
+        preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
+            if True, the image will be resized to fit within the target size while maintaining its aspect ratio
+        symmetric_pad: whether to symmetrically pad the image to the target size,
+            if True, the image will be padded equally on both sides to fit the target size
+    """
 
     def __init__(
         self,
@@ -36,25 +50,19 @@ class Resize(T.Resize):
         preserve_aspect_ratio: bool = False,
         symmetric_pad: bool = False,
     ) -> None:
-        super().__init__(size, interpolation, antialias=True)
+        super().__init__(size if isinstance(size, (list, tuple)) else (size, size), interpolation, antialias=True)
         self.preserve_aspect_ratio = preserve_aspect_ratio
         self.symmetric_pad = symmetric_pad
 
-        if not isinstance(self.size, (int, tuple, list)):
-            raise AssertionError("size should be either a tuple, a list or an int")
-
     def forward(
         self,
         img: torch.Tensor,
         target: np.ndarray | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
-        if isinstance(self.size, int):
-            target_ratio = img.shape[-2] / img.shape[-1]
-        else:
-            target_ratio = self.size[0] / self.size[1]
+        target_ratio = self.size[0] / self.size[1]
         actual_ratio = img.shape[-2] / img.shape[-1]
 
-        if not self.preserve_aspect_ratio or (target_ratio == actual_ratio and (isinstance(self.size, (tuple, list)))):
+        if not self.preserve_aspect_ratio or (target_ratio == actual_ratio):
             # If we don't preserve the aspect ratio or the wanted aspect ratio is the same than the original one
             # We can use with the regular resize
             if target is not None:
@@ -62,16 +70,10 @@ class Resize(T.Resize):
             return super().forward(img)
         else:
             # Resize
-            if isinstance(self.size, (tuple, list)):
-                if actual_ratio > target_ratio:
-                    tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1))
-                else:
-                    tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1])
-            elif isinstance(self.size, int):  # self.size is the longest side, infer the other
-                if img.shape[-2] <= img.shape[-1]:
-                    tmp_size = (max(int(self.size * actual_ratio), 1), self.size)
-                else:
-                    tmp_size = (self.size, max(int(self.size / actual_ratio), 1))
+            if actual_ratio > target_ratio:
+                tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1))
+            else:
+                tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1])
 
             # Scale image
             img = F.resize(img, tmp_size, self.interpolation, antialias=True)
@@ -93,14 +95,14 @@ class Resize(T.Resize):
         if self.preserve_aspect_ratio:
             # Get absolute coords
             if target.shape[1:] == (4,):
-                if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
+                if self.symmetric_pad:
                     target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
                     target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
                 else:
                     target[:, [0, 2]] *= raw_shape[-1] / img.shape[-1]
                     target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
             elif target.shape[1:] == (4, 2):
-                if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
+                if self.symmetric_pad:
                     target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
                     target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
                 else:
@@ -143,9 +145,9 @@ class GaussianNoise(torch.nn.Module):
         # Reshape the distribution
         noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std
         if x.dtype == torch.uint8:
-            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)  # type: ignore[attr-defined]
+            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
         else:
-            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)  # type: ignore[attr-defined]
+            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)
 
     def extra_repr(self) -> str:
         return f"mean={self.mean}, std={self.std}"
@@ -233,7 +235,7 @@ class RandomShadow(torch.nn.Module):
         try:
             if x.dtype == torch.uint8:
                 return (
-                    (  # type: ignore[attr-defined]
+                    (
                         255
                         * random_shadow(
                             x.to(dtype=torch.float32) / 255,
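
Because `__init__` now normalizes `size` to a `(height, width)` tuple, `Resize.forward` keeps only the two tuple branches of the letterbox computation. A worked check of the `tmp_size` arithmetic from the hunks above, with input values assumed purely for illustration:

    # Assume a 32 x 128 input resized with Resize((64, 64), preserve_aspect_ratio=True)
    size = (64, 64)
    h, w = 32, 128

    target_ratio = size[0] / size[1]  # 1.0
    actual_ratio = h / w              # 0.25

    # Same two branches as in Resize.forward above
    if actual_ratio > target_ratio:  # taller than the target: fit the height
        tmp_size = (size[0], max(int(size[0] / actual_ratio), 1))
    else:                            # wider than the target: fit the width (this case)
        tmp_size = (max(int(size[1] * actual_ratio), 1), size[1])

    print(tmp_size)  # (16, 64): scaled to 16 x 64, then padded up to 64 x 64
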
doctr/utils/data.py CHANGED
@@ -92,7 +92,7 @@ def download_from_url(
         # Create folder hierarchy
         folder_path.mkdir(parents=True, exist_ok=True)
     except OSError:
-        error_message = f"Failed creating cache direcotry at {folder_path}"
+        error_message = f"Failed creating cache directory at {folder_path}"
         if os.environ.get("DOCTR_CACHE_DIR", ""):
             error_message += " using path from 'DOCTR_CACHE_DIR' environment variable."
         else:
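
As the corrected error message hints, the cache root honors the DOCTR_CACHE_DIR environment variable when set. A small sketch (the URL is a placeholder, and the cache path is chosen for illustration; `cache_subdir="models"` matches the call in `load_pretrained_params` above):

    import os

    # Redirect the cache root before calling into doctr's download helpers
    os.environ["DOCTR_CACHE_DIR"] = "/tmp/doctr-cache"

    from doctr.utils.data import download_from_url

    path = download_from_url(
        "https://yoursource.com/yourcheckpoint-yourhash.pt",  # placeholder URL
        cache_subdir="models",
    )
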
doctr/utils/geometry.py CHANGED
@@ -300,7 +300,7 @@ def rotate_image(
     # Compute the expanded padding
     exp_img: np.ndarray
     if expand:
-        exp_shape = compute_expanded_shape(image.shape[:2], angle)  # type: ignore[arg-type]
+        exp_shape = compute_expanded_shape(image.shape[:2], angle)
         h_pad, w_pad = (
             int(max(0, ceil(exp_shape[0] - image.shape[0]))),
             int(max(0, ceil(exp_shape[1] - image.shape[1]))),
@@ -390,14 +390,13 @@ def convert_to_relative_coords(geoms: np.ndarray, img_shape: tuple[int, int]) ->
     raise ValueError(f"invalid format for arg `geoms`: {geoms.shape}")
 
 
-def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> list[np.ndarray]:
+def extract_crops(img: np.ndarray, boxes: np.ndarray) -> list[np.ndarray]:
     """Created cropped images from list of bounding boxes
 
     Args:
         img: input image
         boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative
             coordinates (xmin, ymin, xmax, ymax)
-        channels_last: whether the channel dimensions is the last one instead of the last one
 
     Returns:
         list of cropped images
@@ -409,21 +408,19 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True
 
     # Project relative coordinates
     _boxes = boxes.copy()
-    h, w = img.shape[:2] if channels_last else img.shape[-2:]
+    h, w = img.shape[:2]
     if not np.issubdtype(_boxes.dtype, np.integer):
         _boxes[:, [0, 2]] *= w
         _boxes[:, [1, 3]] *= h
     _boxes = _boxes.round().astype(int)
     # Add last index
     _boxes[2:] += 1
-    if channels_last:
-        return deepcopy([img[box[1] : box[3], box[0] : box[2]] for box in _boxes])
 
-    return deepcopy([img[:, box[1] : box[3], box[0] : box[2]] for box in _boxes])
+    return deepcopy([img[box[1] : box[3], box[0] : box[2]] for box in _boxes])
 
 
 def extract_rcrops(
-    img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True, assume_horizontal: bool = False
+    img: np.ndarray, polys: np.ndarray, dtype=np.float32, assume_horizontal: bool = False
 ) -> list[np.ndarray]:
     """Created cropped images from list of rotated bounding boxes
 
@@ -431,7 +428,6 @@ def extract_rcrops(
         img: input image
         polys: bounding boxes of shape (N, 4, 2)
         dtype: target data type of bounding boxes
-        channels_last: whether the channel dimensions is the last one instead of the last one
         assume_horizontal: whether the boxes are assumed to be only horizontally oriented
 
     Returns:
@@ -444,12 +440,12 @@ def extract_rcrops(
 
     # Project relative coordinates
     _boxes = polys.copy()
-    height, width = img.shape[:2] if channels_last else img.shape[-2:]
+    height, width = img.shape[:2]
     if not np.issubdtype(_boxes.dtype, np.integer):
         _boxes[:, :, 0] *= width
         _boxes[:, :, 1] *= height
 
-    src_img = img if channels_last else img.transpose(1, 2, 0)
+    src_img = img
 
     # Handle only horizontal oriented boxes
     if assume_horizontal:
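
With `channels_last` removed from both crop helpers, images must now be passed channels-last (H x W x C) unconditionally. A short sketch of the new calling convention (shapes and boxes are illustrative):

    import numpy as np

    from doctr.utils.geometry import extract_crops

    img = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)  # H x W x C expected
    # Relative (xmin, ymin, xmax, ymax) boxes, as documented above
    boxes = np.array([[0.1, 0.1, 0.4, 0.3], [0.5, 0.5, 0.9, 0.8]], dtype=np.float32)

    crops = extract_crops(img, boxes)  # the channels_last argument no longer exists
    print([crop.shape for crop in crops])

    # A CHW array now has to be transposed by the caller before cropping
    chw = np.random.randint(0, 255, (3, 512, 512), dtype=np.uint8)
    crops = extract_crops(chw.transpose(1, 2, 0), boxes)
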
doctr/utils/visualization.py CHANGED
@@ -148,7 +148,7 @@ def get_colors(num_colors: int) -> list[tuple[float, float, float]]:
         hue = i / 360.0
         lightness = (50 + np.random.rand() * 10) / 100.0
         saturation = (90 + np.random.rand() * 10) / 100.0
-        colors.append(colorsys.hls_to_rgb(hue, lightness, saturation))
+        colors.append(colorsys.hls_to_rgb(hue, lightness, saturation))  # type: ignore[arg-type]
     return colors
 
 
doctr/version.py CHANGED
@@ -1 +1 @@
-__version__ = 'v0.11.0'
+__version__ = 'v1.0.0'