python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/transforms/modules/base.py

@@ -1,11 +1,12 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
 import random
-from typing import Any, Callable, List, Optional, Tuple, Union
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
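The import change above is the typing migration that runs through this whole release: builtin generics (PEP 585) replace `typing.List`/`typing.Tuple`, `X | Y` unions (PEP 604) replace `Optional`/`Union`, and `Callable` now comes from `collections.abc`. A minimal before/after sketch (the function is hypothetical, purely to illustrate the rewritten signatures):

    from typing import Any

    import numpy as np

    # 0.9.x style:   target: Optional[np.ndarray], returns Union[Any, Tuple[Any, np.ndarray]]
    # 0.11.x style (PEP 585/604 syntax, which requires Python >= 3.10 at runtime):
    def example(img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]:
        # Return the image alone when no target is given, else the pair
        return img if target is None else (img, target)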
 
@@ -21,37 +22,36 @@ class SampleCompose(NestedObject):
 
     .. tabs::
 
-        .. tab:: TensorFlow
+        .. tab:: PyTorch
 
             .. code:: python
 
                 >>> import numpy as np
-                >>> import tensorflow as tf
+                >>> import torch
                 >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-                >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-                >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
+                >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
+                >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
 
-        .. tab:: PyTorch
+        .. tab:: TensorFlow
 
             .. code:: python
 
                 >>> import numpy as np
-                >>> import torch
+                >>> import tensorflow as tf
                 >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-                >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-                >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
+                >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
+                >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
 
     Args:
-    ----
         transforms: list of transformation modules
     """
 
-    _children_names: List[str] = ["sample_transforms"]
+    _children_names: list[str] = ["sample_transforms"]
 
-    def __init__(self, transforms: List[Callable[[Any, Any], Tuple[Any, Any]]]) -> None:
+    def __init__(self, transforms: list[Callable[[Any, Any], tuple[Any, Any]]]) -> None:
         self.sample_transforms = transforms
 
-    def __call__(self, x: Any, target: Any) -> Tuple[Any, Any]:
+    def __call__(self, x: Any, target: Any) -> tuple[Any, Any]:
         for t in self.sample_transforms:
             x, target = t(x, target)
 
@@ -63,35 +63,34 @@ class ImageTransform(NestedObject):
 
     .. tabs::
 
-        .. tab:: TensorFlow
+        .. tab:: PyTorch
 
             .. code:: python
 
-                >>> import tensorflow as tf
+                >>> import torch
                 >>> from doctr.transforms import ImageTransform, ColorInversion
                 >>> transfo = ImageTransform(ColorInversion((32, 32)))
-                >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
+                >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
 
-        .. tab:: PyTorch
+        .. tab:: TensorFlow
 
             .. code:: python
 
-                >>> import torch
+                >>> import tensorflow as tf
                 >>> from doctr.transforms import ImageTransform, ColorInversion
                 >>> transfo = ImageTransform(ColorInversion((32, 32)))
-                >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
+                >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
 
     Args:
-    ----
         transform: the image transformation module to wrap
     """
 
-    _children_names: List[str] = ["img_transform"]
+    _children_names: list[str] = ["img_transform"]
 
     def __init__(self, transform: Callable[[Any], Any]) -> None:
         self.img_transform = transform
 
-    def __call__(self, img: Any, target: Any) -> Tuple[Any, Any]:
+    def __call__(self, img: Any, target: Any) -> tuple[Any, Any]:
         img = self.img_transform(img)
         return img, target
 
@@ -102,26 +101,25 @@ class ColorInversion(NestedObject):
 
     .. tabs::
 
-        .. tab:: TensorFlow
+        .. tab:: PyTorch
 
             .. code:: python
 
-                >>> import tensorflow as tf
+                >>> import torch
                 >>> from doctr.transforms import ColorInversion
                 >>> transfo = ColorInversion(min_val=0.6)
-                >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
+                >>> out = transfo(torch.rand(8, 64, 64, 3))
 
-        .. tab:: PyTorch
+        .. tab:: TensorFlow
 
             .. code:: python
 
-                >>> import torch
+                >>> import tensorflow as tf
                 >>> from doctr.transforms import ColorInversion
                 >>> transfo = ColorInversion(min_val=0.6)
-                >>> out = transfo(torch.rand(8, 64, 64, 3))
+                >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
 
     Args:
-    ----
         min_val: range [min_val, 1] to colorize RGB pixels
     """
 
@@ -140,35 +138,34 @@ class OneOf(NestedObject):
 
     .. tabs::
 
-        .. tab:: TensorFlow
+        .. tab:: PyTorch
 
             .. code:: python
 
-                >>> import tensorflow as tf
+                >>> import torch
                 >>> from doctr.transforms import OneOf
                 >>> transfo = OneOf([JpegQuality(), Gamma()])
-                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+                >>> out = transfo(torch.rand(1, 64, 64, 3))
 
-        .. tab:: PyTorch
+        .. tab:: TensorFlow
 
             .. code:: python
 
-                >>> import torch
+                >>> import tensorflow as tf
                 >>> from doctr.transforms import OneOf
                 >>> transfo = OneOf([JpegQuality(), Gamma()])
-                >>> out = transfo(torch.rand(1, 64, 64, 3))
+                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
 
     Args:
-    ----
         transforms: list of transformations, one only will be picked
     """
 
-    _children_names: List[str] = ["transforms"]
+    _children_names: list[str] = ["transforms"]
 
-    def __init__(self, transforms: List[Callable[[Any], Any]]) -> None:
+    def __init__(self, transforms: list[Callable[[Any], Any]]) -> None:
        self.transforms = transforms
 
-    def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
+    def __call__(self, img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]:
         # Pick transformation
         transfo = self.transforms[int(random.random() * len(self.transforms))]
         # Apply
@@ -180,26 +177,25 @@ class RandomApply(NestedObject):
 
     .. tabs::
 
-        .. tab:: TensorFlow
+        .. tab:: PyTorch
 
             .. code:: python
 
-                >>> import tensorflow as tf
+                >>> import torch
                 >>> from doctr.transforms import RandomApply
                 >>> transfo = RandomApply(Gamma(), p=.5)
-                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+                >>> out = transfo(torch.rand(1, 64, 64, 3))
 
-        .. tab:: PyTorch
+        .. tab:: TensorFlow
 
             .. code:: python
 
-                >>> import torch
+                >>> import tensorflow as tf
                 >>> from doctr.transforms import RandomApply
                 >>> transfo = RandomApply(Gamma(), p=.5)
-                >>> out = transfo(torch.rand(1, 64, 64, 3))
+                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
 
     Args:
-    ----
         transform: transformation to apply
         p: probability to apply
     """
@@ -211,7 +207,7 @@ class RandomApply(NestedObject):
     def extra_repr(self) -> str:
         return f"transform={self.transform}, p={self.p}"
 
-    def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
+    def __call__(self, img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]:
         if random.random() < self.p:
             return self.transform(img) if target is None else self.transform(img, target)  # type: ignore[call-arg]
         return img if target is None else (img, target)
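As the updated `__call__` signature shows, `RandomApply` (like `OneOf` above) takes an optional target and only forwards it to the wrapped transform when one is given. A short usage sketch, assuming the PyTorch backend and mirroring the tensor shapes used in the docstrings:

    import numpy as np
    import torch
    from doctr.transforms import ColorInversion, RandomApply, RandomRotate

    # Image-only transform: call with the image alone, get the image back
    transfo = RandomApply(ColorInversion(), p=0.5)
    out = transfo(torch.rand(1, 64, 64, 3))

    # Sample-level transform: the target is forwarded to the wrapped transform
    transfo = RandomApply(RandomRotate(30), p=0.5)
    out, boxes = transfo(torch.rand(1, 64, 64, 3), np.zeros((2, 4)))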
@@ -224,9 +220,7 @@ class RandomRotate(NestedObject):
         :align: center
 
     Args:
-    ----
-        max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in
-            [-max_angle, max_angle]
+        max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in [-max_angle, max_angle]
         expand: whether the image should be padded before the rotation
     """
 
@@ -237,7 +231,7 @@ class RandomRotate(NestedObject):
     def extra_repr(self) -> str:
         return f"max_angle={self.max_angle}, expand={self.expand}"
 
-    def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
+    def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]:
         angle = random.uniform(-self.max_angle, self.max_angle)
         r_img, r_polys = F.rotate_sample(img, target, angle, self.expand)
         # Removes deleted boxes
@@ -249,19 +243,18 @@ class RandomCrop(NestedObject):
     """Randomly crop a tensor image and its boxes
 
     Args:
-    ----
         scale: tuple of floats, relative (min_area, max_area) of the crop
         ratio: tuple of float, relative (min_ratio, max_ratio) where ratio = h/w
     """
 
-    def __init__(self, scale: Tuple[float, float] = (0.08, 1.0), ratio: Tuple[float, float] = (0.75, 1.33)) -> None:
+    def __init__(self, scale: tuple[float, float] = (0.08, 1.0), ratio: tuple[float, float] = (0.75, 1.33)) -> None:
         self.scale = scale
         self.ratio = ratio
 
     def extra_repr(self) -> str:
         return f"scale={self.scale}, ratio={self.ratio}"
 
-    def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
+    def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]:
         scale = random.uniform(self.scale[0], self.scale[1])
         ratio = random.uniform(self.ratio[0], self.ratio[1])
 
doctr/transforms/modules/pytorch.py

@@ -1,21 +1,29 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
-from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 from PIL.Image import Image
+from scipy.ndimage import gaussian_filter
 from torch.nn.functional import pad
 from torchvision.transforms import functional as F
 from torchvision.transforms import transforms as T
 
 from ..functional.pytorch import random_shadow
 
-__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"]
+__all__ = [
+    "Resize",
+    "GaussianNoise",
+    "ChannelShuffle",
+    "RandomHorizontalFlip",
+    "RandomShadow",
+    "RandomResize",
+    "GaussianBlur",
+]
 
 
 class Resize(T.Resize):
@@ -23,7 +31,7 @@ class Resize(T.Resize):
 
     def __init__(
         self,
-        size: Union[int, Tuple[int, int]],
+        size: int | tuple[int, int],
         interpolation=F.InterpolationMode.BILINEAR,
         preserve_aspect_ratio: bool = False,
         symmetric_pad: bool = False,
@@ -38,8 +46,8 @@ class Resize(T.Resize):
     def forward(
         self,
         img: torch.Tensor,
-        target: Optional[np.ndarray] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, np.ndarray]]:
+        target: np.ndarray | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
         if isinstance(self.size, int):
             target_ratio = img.shape[-2] / img.shape[-1]
         else:
@@ -74,16 +82,18 @@ class Resize(T.Resize):
         if self.symmetric_pad:
             half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2))
             _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1])
+        # Pad image
         img = pad(img, _pad)
 
         # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
         if target is not None:
+            if self.symmetric_pad:
+                offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
+
             if self.preserve_aspect_ratio:
                 # Get absolute coords
                 if target.shape[1:] == (4,):
                     if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
                         target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
                         target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
                     else:
@@ -91,16 +101,15 @@ class Resize(T.Resize):
                         target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
                 elif target.shape[1:] == (4, 2):
                     if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
                         target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
                         target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
                     else:
                         target[..., 0] *= raw_shape[-1] / img.shape[-1]
                         target[..., 1] *= raw_shape[-2] / img.shape[-2]
                 else:
-                    raise AssertionError
-            return img, target
+                    raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")
+
+            return img, np.clip(target, 0, 1)
 
         return img
 
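The box handling above is a behavior fix: the symmetric-pad offset is now computed whenever `symmetric_pad` is set (previously it was only assigned inside the `np.max(target) <= 1` check), and the returned boxes are clipped to [0, 1]. A hedged sketch of the resulting behavior, assuming relative `(n_boxes, 4)` boxes and the PyTorch backend:

    import numpy as np
    import torch
    from doctr.transforms import Resize

    transfo = Resize((32, 32), preserve_aspect_ratio=True, symmetric_pad=True)
    img = torch.rand(3, 64, 128)  # landscape image: resized to (16, 32), then padded top/bottom
    boxes = np.array([[0.0, 0.0, 1.0, 1.0]])  # one relative box (xmin, ymin, xmax, ymax)

    out_img, out_boxes = transfo(img, boxes)
    # out_img: (3, 32, 32); out_boxes carry the vertical padding offset
    # (y values land in [0.25, 0.75]) and are clipped to [0, 1]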
@@ -121,7 +130,6 @@ class GaussianNoise(torch.nn.Module):
     >>> out = transfo(torch.rand((3, 224, 224)))
 
     Args:
-    ----
         mean : mean of the gaussian distribution
         std : std of the gaussian distribution
     """
@@ -135,14 +143,47 @@ class GaussianNoise(torch.nn.Module):
         # Reshape the distribution
         noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std
         if x.dtype == torch.uint8:
-            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
+            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)  # type: ignore[attr-defined]
         else:
-            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)
+            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)  # type: ignore[attr-defined]
 
     def extra_repr(self) -> str:
         return f"mean={self.mean}, std={self.std}"
 
 
+class GaussianBlur(torch.nn.Module):
+    """Apply Gaussian Blur to the input tensor
+
+    >>> import torch
+    >>> from doctr.transforms import GaussianBlur
+    >>> transfo = GaussianBlur(sigma=(0.0, 1.0))
+
+    Args:
+        sigma : standard deviation range for the gaussian kernel
+    """
+
+    def __init__(self, sigma: tuple[float, float]) -> None:
+        super().__init__()
+        self.sigma_range = sigma
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Sample a random sigma value within the specified range
+        sigma = torch.empty(1).uniform_(*self.sigma_range).item()
+
+        # Apply Gaussian blur along spatial dimensions only
+        blurred = torch.tensor(
+            gaussian_filter(
+                x.numpy(),
+                sigma=sigma,
+                mode="reflect",
+                truncate=4.0,
+            ),
+            dtype=x.dtype,
+            device=x.device,
+        )
+        return blurred
+
+
 class ChannelShuffle(torch.nn.Module):
     """Randomly shuffle channel order of a given image"""
 
@@ -158,9 +199,7 @@ class ChannelShuffle(torch.nn.Module):
 class RandomHorizontalFlip(T.RandomHorizontalFlip):
     """Randomly flip the input image horizontally"""
 
-    def forward(
-        self, img: Union[torch.Tensor, Image], target: np.ndarray
-    ) -> Tuple[Union[torch.Tensor, Image], np.ndarray]:
+    def forward(self, img: torch.Tensor | Image, target: np.ndarray) -> tuple[torch.Tensor | Image, np.ndarray]:
         if torch.rand(1) < self.p:
             _img = F.hflip(img)
             _target = target.copy()
@@ -182,11 +221,10 @@ class RandomShadow(torch.nn.Module):
     >>> out = transfo(torch.rand((3, 64, 64)))
 
     Args:
-    ----
         opacity_range : minimum and maximum opacity of the shade
     """
 
-    def __init__(self, opacity_range: Optional[Tuple[float, float]] = None) -> None:
+    def __init__(self, opacity_range: tuple[float, float] | None = None) -> None:
         super().__init__()
         self.opacity_range = opacity_range if isinstance(opacity_range, tuple) else (0.2, 0.8)
 
@@ -195,7 +233,7 @@ class RandomShadow(torch.nn.Module):
         try:
             if x.dtype == torch.uint8:
                 return (
-                    (
+                    (  # type: ignore[attr-defined]
                         255
                         * random_shadow(
                             x.to(dtype=torch.float32) / 255,
@@ -224,20 +262,19 @@ class RandomResize(torch.nn.Module):
     >>> out = transfo(torch.rand((3, 64, 64)))
 
     Args:
-    ----
         scale_range: range of the resizing factor for width and height (independently)
         preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
-        given a float value, the aspect ratio will be preserved with this probability
+            given a float value, the aspect ratio will be preserved with this probability
         symmetric_pad: whether to symmetrically pad the image,
-        given a float value, the symmetric padding will be applied with this probability
+            given a float value, the symmetric padding will be applied with this probability
         p: probability to apply the transformation
     """
 
     def __init__(
         self,
-        scale_range: Tuple[float, float] = (0.3, 0.9),
-        preserve_aspect_ratio: Union[bool, float] = False,
-        symmetric_pad: Union[bool, float] = False,
+        scale_range: tuple[float, float] = (0.3, 0.9),
+        preserve_aspect_ratio: bool | float = False,
+        symmetric_pad: bool | float = False,
         p: float = 0.5,
     ) -> None:
         super().__init__()
@@ -247,7 +284,7 @@ class RandomResize(torch.nn.Module):
         self.p = p
         self._resize = Resize
 
-    def forward(self, img: torch.Tensor, target: np.ndarray) -> Tuple[torch.Tensor, np.ndarray]:
+    def forward(self, img: torch.Tensor, target: np.ndarray) -> tuple[torch.Tensor, np.ndarray]:
         if torch.rand(1) < self.p:
             scale_h = np.random.uniform(*self.scale_range)
             scale_w = np.random.uniform(*self.scale_range)
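`RandomResize` itself only changes type hints here: it still samples independent height/width scale factors from `scale_range` and delegates to `Resize`, with float values for `preserve_aspect_ratio`/`symmetric_pad` acting as per-call probabilities (per its docstring). A brief usage sketch, assuming the PyTorch backend:

    import numpy as np
    import torch
    from doctr.transforms import RandomResize

    # Half the time, rescale H and W independently by factors drawn from [0.3, 0.9];
    # aspect-ratio preservation and symmetric padding each kick in with probability 0.5
    transfo = RandomResize(scale_range=(0.3, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.5)
    out_img, out_boxes = transfo(torch.rand(3, 64, 64), np.zeros((2, 4)))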