python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +8 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +7 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +4 -5
  17. doctr/datasets/ic13.py +4 -5
  18. doctr/datasets/iiit5k.py +6 -5
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +6 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +6 -5
  27. doctr/datasets/svhn.py +6 -5
  28. doctr/datasets/svt.py +4 -5
  29. doctr/datasets/synthtext.py +4 -5
  30. doctr/datasets/utils.py +34 -29
  31. doctr/datasets/vocabs.py +17 -7
  32. doctr/datasets/wildreceipt.py +14 -10
  33. doctr/file_utils.py +2 -7
  34. doctr/io/elements.py +59 -79
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +30 -48
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +8 -11
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +5 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +8 -21
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +6 -8
  52. doctr/models/classification/predictor/tensorflow.py +6 -8
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +20 -31
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +8 -15
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +9 -12
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +6 -12
  65. doctr/models/classification/zoo.py +19 -14
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +14 -26
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +14 -23
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +5 -6
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +3 -7
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +4 -5
  91. doctr/models/kie_predictor/pytorch.py +18 -19
  92. doctr/models/kie_predictor/tensorflow.py +13 -14
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -10
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  101. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +28 -29
  104. doctr/models/predictor/pytorch.py +12 -13
  105. doctr/models/predictor/tensorflow.py +8 -9
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +10 -14
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +11 -23
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +12 -22
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +16 -22
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +12 -21
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +12 -20
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +14 -17
  136. doctr/models/utils/tensorflow.py +17 -16
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +20 -28
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +58 -22
  145. doctr/transforms/modules/tensorflow.py +18 -32
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +16 -47
  150. doctr/utils/metrics.py +17 -37
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +9 -13
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.10.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,13 +1,13 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from copy import deepcopy
7
- from typing import Tuple
8
7
 
9
8
  import numpy as np
10
9
  import torch
10
+ from scipy.ndimage import gaussian_filter
11
11
  from torchvision.transforms import functional as F
12
12
 
13
13
  from doctr.utils.geometry import rotate_abs_geoms
@@ -21,12 +21,10 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
21
21
  """Invert the colors of an image
22
22
 
23
23
  Args:
24
- ----
25
24
  img : torch.Tensor, the image to invert
26
25
  min_val : minimum value of the random shift
27
26
 
28
27
  Returns:
29
- -------
30
28
  the inverted image
31
29
  """
32
30
  out = F.rgb_to_grayscale(img, num_output_channels=3)
@@ -35,9 +33,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
35
33
  rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
36
34
  # Inverse the color
37
35
  if out.dtype == torch.uint8:
38
- out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
36
+ out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8) # type: ignore[attr-defined]
39
37
  else:
40
- out = out * rgb_shift.to(dtype=out.dtype)
38
+ out = out * rgb_shift.to(dtype=out.dtype) # type: ignore[attr-defined]
41
39
  # Inverse the color
42
40
  out = 255 - out if out.dtype == torch.uint8 else 1 - out
43
41
  return out
@@ -48,18 +46,16 @@ def rotate_sample(
48
46
  geoms: np.ndarray,
49
47
  angle: float,
50
48
  expand: bool = False,
51
- ) -> Tuple[torch.Tensor, np.ndarray]:
49
+ ) -> tuple[torch.Tensor, np.ndarray]:
52
50
  """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
53
51
 
54
52
  Args:
55
- ----
56
53
  img: image to rotate
57
54
  geoms: array of geometries of shape (N, 4) or (N, 4, 2)
58
55
  angle: angle in degrees. +: counter-clockwise, -: clockwise
59
56
  expand: whether the image should be padded before the rotation
60
57
 
61
58
  Returns:
62
- -------
63
59
  A tuple of rotated img (tensor), rotated geometries of shape (N, 4, 2)
64
60
  """
65
61
  rotated_img = F.rotate(img, angle=angle, fill=0, expand=expand) # Interpolation NEAREST by default
@@ -81,7 +77,7 @@ def rotate_sample(
81
77
  rotated_geoms: np.ndarray = rotate_abs_geoms(
82
78
  _geoms,
83
79
  angle,
84
- img.shape[1:], # type: ignore[arg-type]
80
+ img.shape[1:],
85
81
  expand,
86
82
  ).astype(np.float32)
87
83
 
@@ -93,18 +89,16 @@ def rotate_sample(
93
89
 
94
90
 
95
91
  def crop_detection(
96
- img: torch.Tensor, boxes: np.ndarray, crop_box: Tuple[float, float, float, float]
97
- ) -> Tuple[torch.Tensor, np.ndarray]:
92
+ img: torch.Tensor, boxes: np.ndarray, crop_box: tuple[float, float, float, float]
93
+ ) -> tuple[torch.Tensor, np.ndarray]:
98
94
  """Crop and image and associated bboxes
99
95
 
100
96
  Args:
101
- ----
102
97
  img: image to crop
103
98
  boxes: array of boxes to clip, absolute (int) or relative (float)
104
99
  crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords.
105
100
 
106
101
  Returns:
107
- -------
108
102
  A tuple of cropped image, cropped boxes, where the image is not resized.
109
103
  """
110
104
  if any(val < 0 or val > 1 for val in crop_box):
@@ -119,27 +113,25 @@ def crop_detection(
119
113
  return cropped_img, boxes
120
114
 
121
115
 
122
- def random_shadow(img: torch.Tensor, opacity_range: Tuple[float, float], **kwargs) -> torch.Tensor:
123
- """Crop and image and associated bboxes
116
+ def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwargs) -> torch.Tensor:
117
+ """Apply a random shadow effect to an image using NumPy for blurring.
124
118
 
125
119
  Args:
126
- ----
127
- img: image to modify
128
- opacity_range: the minimum and maximum desired opacity of the shadow
129
- **kwargs: additional arguments to pass to `create_shadow_mask`
120
+ img: Image to modify (C, H, W) as a PyTorch tensor.
121
+ opacity_range: The minimum and maximum desired opacity of the shadow.
122
+ **kwargs: Additional arguments to pass to `create_shadow_mask`.
130
123
 
131
124
  Returns:
132
- -------
133
- shaded image
125
+ Shadowed image as a PyTorch tensor (same shape as input).
134
126
  """
135
- shadow_mask = create_shadow_mask(img.shape[1:], **kwargs) # type: ignore[arg-type]
136
-
127
+ shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
137
128
  opacity = np.random.uniform(*opacity_range)
138
- shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])
139
129
 
140
- # Add some blur to make it believable
141
- k = 7 + 2 * int(4 * np.random.rand(1))
130
+ # Apply Gaussian blur to the shadow mask
142
131
  sigma = np.random.uniform(0.5, 5.0)
143
- shadow_tensor = F.gaussian_blur(shadow_tensor, k, sigma=[sigma, sigma])
132
+ blurred_mask = gaussian_filter(shadow_mask, sigma=sigma)
133
+
134
+ shadow_tensor = 1 - torch.from_numpy(blurred_mask).float()
135
+ shadow_tensor = shadow_tensor.to(img.device).unsqueeze(0) # Add channel dimension
144
136
 
145
137
  return opacity * shadow_tensor * img + (1 - opacity) * img
@@ -1,12 +1,12 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  import math
7
7
  import random
8
+ from collections.abc import Iterable
8
9
  from copy import deepcopy
9
- from typing import Iterable, Optional, Tuple, Union
10
10
 
11
11
  import numpy as np
12
12
  import tensorflow as tf
@@ -22,12 +22,10 @@ def invert_colors(img: tf.Tensor, min_val: float = 0.6) -> tf.Tensor:
22
22
  """Invert the colors of an image
23
23
 
24
24
  Args:
25
- ----
26
25
  img : tf.Tensor, the image to invert
27
26
  min_val : minimum value of the random shift
28
27
 
29
28
  Returns:
30
- -------
31
29
  the inverted image
32
30
  """
33
31
  out = tf.image.rgb_to_grayscale(img) # Convert to gray
@@ -48,13 +46,11 @@ def rotated_img_tensor(img: tf.Tensor, angle: float, expand: bool = False) -> tf
48
46
  """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
49
47
 
50
48
  Args:
51
- ----
52
49
  img: image to rotate
53
50
  angle: angle in degrees. +: counter-clockwise, -: clockwise
54
51
  expand: whether the image should be padded before the rotation
55
52
 
56
53
  Returns:
57
- -------
58
54
  the rotated image (tensor)
59
55
  """
60
56
  # Compute the expanded padding
@@ -103,18 +99,16 @@ def rotate_sample(
103
99
  geoms: np.ndarray,
104
100
  angle: float,
105
101
  expand: bool = False,
106
- ) -> Tuple[tf.Tensor, np.ndarray]:
102
+ ) -> tuple[tf.Tensor, np.ndarray]:
107
103
  """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
108
104
 
109
105
  Args:
110
- ----
111
106
  img: image to rotate
112
107
  geoms: array of geometries of shape (N, 4) or (N, 4, 2)
113
108
  angle: angle in degrees. +: counter-clockwise, -: clockwise
114
109
  expand: whether the image should be padded before the rotation
115
110
 
116
111
  Returns:
117
- -------
118
112
  A tuple of rotated img (tensor), rotated boxes (np array)
119
113
  """
120
114
  # Rotated the image
@@ -140,22 +134,20 @@ def rotate_sample(
140
134
  rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[1]
141
135
  rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[0]
142
136
 
143
- return rotated_img, np.clip(rotated_geoms, 0, 1)
137
+ return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)
144
138
 
145
139
 
146
140
  def crop_detection(
147
- img: tf.Tensor, boxes: np.ndarray, crop_box: Tuple[float, float, float, float]
148
- ) -> Tuple[tf.Tensor, np.ndarray]:
141
+ img: tf.Tensor, boxes: np.ndarray, crop_box: tuple[float, float, float, float]
142
+ ) -> tuple[tf.Tensor, np.ndarray]:
149
143
  """Crop and image and associated bboxes
150
144
 
151
145
  Args:
152
- ----
153
146
  img: image to crop
154
147
  boxes: array of boxes to clip, absolute (int) or relative (float)
155
148
  crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords.
156
149
 
157
150
  Returns:
158
- -------
159
151
  A tuple of cropped image, cropped boxes, where the image is not resized.
160
152
  """
161
153
  if any(val < 0 or val > 1 for val in crop_box):
@@ -172,16 +164,15 @@ def crop_detection(
172
164
 
173
165
  def _gaussian_filter(
174
166
  img: tf.Tensor,
175
- kernel_size: Union[int, Iterable[int]],
167
+ kernel_size: int | Iterable[int],
176
168
  sigma: float,
177
- mode: Optional[str] = None,
178
- pad_value: Optional[int] = 0,
169
+ mode: str | None = None,
170
+ pad_value: int = 0,
179
171
  ):
180
172
  """Apply Gaussian filter to image.
181
173
  Adapted from: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/filters.py
182
174
 
183
175
  Args:
184
- ----
185
176
  img: image to filter of shape (N, H, W, C)
186
177
  kernel_size: kernel size of the filter
187
178
  sigma: standard deviation of the Gaussian filter
@@ -189,7 +180,6 @@ def _gaussian_filter(
189
180
  pad_value: value to pad the image with
190
181
 
191
182
  Returns:
192
- -------
193
183
  A tensor of shape (N, H, W, C)
194
184
  """
195
185
  ksize = tf.convert_to_tensor(tf.broadcast_to(kernel_size, [2]), dtype=tf.int32)
@@ -235,17 +225,15 @@ def _gaussian_filter(
235
225
  return tf.nn.depthwise_conv2d(img, g, [1, 1, 1, 1], padding="VALID", data_format="NHWC")
236
226
 
237
227
 
238
- def random_shadow(img: tf.Tensor, opacity_range: Tuple[float, float], **kwargs) -> tf.Tensor:
228
+ def random_shadow(img: tf.Tensor, opacity_range: tuple[float, float], **kwargs) -> tf.Tensor:
239
229
  """Apply a random shadow to a given image
240
230
 
241
231
  Args:
242
- ----
243
232
  img: image to modify
244
233
  opacity_range: the minimum and maximum desired opacity of the shadow
245
234
  **kwargs: additional arguments to pass to `create_shadow_mask`
246
235
 
247
236
  Returns:
248
- -------
249
237
  shadowed image
250
238
  """
251
239
  shadow_mask = create_shadow_mask(img.shape[:2], **kwargs)
@@ -2,7 +2,7 @@ from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
3
  from .base import *
4
4
 
5
- if is_tf_available():
6
- from .tensorflow import *
7
- elif is_torch_available():
8
- from .pytorch import * # type: ignore[assignment]
5
+ if is_torch_available():
6
+ from .pytorch import *
7
+ elif is_tf_available():
8
+ from .tensorflow import * # type: ignore[assignment]
@@ -1,11 +1,12 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  import math
7
7
  import random
8
- from typing import Any, Callable, List, Optional, Tuple, Union
8
+ from collections.abc import Callable
9
+ from typing import Any
9
10
 
10
11
  import numpy as np
11
12
 
@@ -21,37 +22,36 @@ class SampleCompose(NestedObject):
21
22
 
22
23
  .. tabs::
23
24
 
24
- .. tab:: TensorFlow
25
+ .. tab:: PyTorch
25
26
 
26
27
  .. code:: python
27
28
 
28
29
  >>> import numpy as np
29
- >>> import tensorflow as tf
30
+ >>> import torch
30
31
  >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
31
- >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
32
- >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
32
+ >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
33
+ >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
33
34
 
34
- .. tab:: PyTorch
35
+ .. tab:: TensorFlow
35
36
 
36
37
  .. code:: python
37
38
 
38
39
  >>> import numpy as np
39
- >>> import torch
40
+ >>> import tensorflow as tf
40
41
  >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
41
- >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
42
- >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
42
+ >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
43
+ >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
43
44
 
44
45
  Args:
45
- ----
46
46
  transforms: list of transformation modules
47
47
  """
48
48
 
49
- _children_names: List[str] = ["sample_transforms"]
49
+ _children_names: list[str] = ["sample_transforms"]
50
50
 
51
- def __init__(self, transforms: List[Callable[[Any, Any], Tuple[Any, Any]]]) -> None:
51
+ def __init__(self, transforms: list[Callable[[Any, Any], tuple[Any, Any]]]) -> None:
52
52
  self.sample_transforms = transforms
53
53
 
54
- def __call__(self, x: Any, target: Any) -> Tuple[Any, Any]:
54
+ def __call__(self, x: Any, target: Any) -> tuple[Any, Any]:
55
55
  for t in self.sample_transforms:
56
56
  x, target = t(x, target)
57
57
 
@@ -63,35 +63,34 @@ class ImageTransform(NestedObject):
63
63
 
64
64
  .. tabs::
65
65
 
66
- .. tab:: TensorFlow
66
+ .. tab:: PyTorch
67
67
 
68
68
  .. code:: python
69
69
 
70
- >>> import tensorflow as tf
70
+ >>> import torch
71
71
  >>> from doctr.transforms import ImageTransform, ColorInversion
72
72
  >>> transfo = ImageTransform(ColorInversion((32, 32)))
73
- >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
73
+ >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
74
74
 
75
- .. tab:: PyTorch
75
+ .. tab:: TensorFlow
76
76
 
77
77
  .. code:: python
78
78
 
79
- >>> import torch
79
+ >>> import tensorflow as tf
80
80
  >>> from doctr.transforms import ImageTransform, ColorInversion
81
81
  >>> transfo = ImageTransform(ColorInversion((32, 32)))
82
- >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
82
+ >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
83
83
 
84
84
  Args:
85
- ----
86
85
  transform: the image transformation module to wrap
87
86
  """
88
87
 
89
- _children_names: List[str] = ["img_transform"]
88
+ _children_names: list[str] = ["img_transform"]
90
89
 
91
90
  def __init__(self, transform: Callable[[Any], Any]) -> None:
92
91
  self.img_transform = transform
93
92
 
94
- def __call__(self, img: Any, target: Any) -> Tuple[Any, Any]:
93
+ def __call__(self, img: Any, target: Any) -> tuple[Any, Any]:
95
94
  img = self.img_transform(img)
96
95
  return img, target
97
96
 
@@ -102,26 +101,25 @@ class ColorInversion(NestedObject):
102
101
 
103
102
  .. tabs::
104
103
 
105
- .. tab:: TensorFlow
104
+ .. tab:: PyTorch
106
105
 
107
106
  .. code:: python
108
107
 
109
- >>> import tensorflow as tf
108
+ >>> import torch
110
109
  >>> from doctr.transforms import ColorInversion
111
110
  >>> transfo = ColorInversion(min_val=0.6)
112
- >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
111
+ >>> out = transfo(torch.rand(8, 64, 64, 3))
113
112
 
114
- .. tab:: PyTorch
113
+ .. tab:: TensorFlow
115
114
 
116
115
  .. code:: python
117
116
 
118
- >>> import torch
117
+ >>> import tensorflow as tf
119
118
  >>> from doctr.transforms import ColorInversion
120
119
  >>> transfo = ColorInversion(min_val=0.6)
121
- >>> out = transfo(torch.rand(8, 64, 64, 3))
120
+ >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
122
121
 
123
122
  Args:
124
- ----
125
123
  min_val: range [min_val, 1] to colorize RGB pixels
126
124
  """
127
125
 
@@ -140,35 +138,34 @@ class OneOf(NestedObject):
140
138
 
141
139
  .. tabs::
142
140
 
143
- .. tab:: TensorFlow
141
+ .. tab:: PyTorch
144
142
 
145
143
  .. code:: python
146
144
 
147
- >>> import tensorflow as tf
145
+ >>> import torch
148
146
  >>> from doctr.transforms import OneOf
149
147
  >>> transfo = OneOf([JpegQuality(), Gamma()])
150
- >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
148
+ >>> out = transfo(torch.rand(1, 64, 64, 3))
151
149
 
152
- .. tab:: PyTorch
150
+ .. tab:: TensorFlow
153
151
 
154
152
  .. code:: python
155
153
 
156
- >>> import torch
154
+ >>> import tensorflow as tf
157
155
  >>> from doctr.transforms import OneOf
158
156
  >>> transfo = OneOf([JpegQuality(), Gamma()])
159
- >>> out = transfo(torch.rand(1, 64, 64, 3))
157
+ >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
160
158
 
161
159
  Args:
162
- ----
163
160
  transforms: list of transformations, one only will be picked
164
161
  """
165
162
 
166
- _children_names: List[str] = ["transforms"]
163
+ _children_names: list[str] = ["transforms"]
167
164
 
168
- def __init__(self, transforms: List[Callable[[Any], Any]]) -> None:
165
+ def __init__(self, transforms: list[Callable[[Any], Any]]) -> None:
169
166
  self.transforms = transforms
170
167
 
171
- def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
168
+ def __call__(self, img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]:
172
169
  # Pick transformation
173
170
  transfo = self.transforms[int(random.random() * len(self.transforms))]
174
171
  # Apply
@@ -180,26 +177,25 @@ class RandomApply(NestedObject):
180
177
 
181
178
  .. tabs::
182
179
 
183
- .. tab:: TensorFlow
180
+ .. tab:: PyTorch
184
181
 
185
182
  .. code:: python
186
183
 
187
- >>> import tensorflow as tf
184
+ >>> import torch
188
185
  >>> from doctr.transforms import RandomApply
189
186
  >>> transfo = RandomApply(Gamma(), p=.5)
190
- >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
187
+ >>> out = transfo(torch.rand(1, 64, 64, 3))
191
188
 
192
- .. tab:: PyTorch
189
+ .. tab:: TensorFlow
193
190
 
194
191
  .. code:: python
195
192
 
196
- >>> import torch
193
+ >>> import tensorflow as tf
197
194
  >>> from doctr.transforms import RandomApply
198
195
  >>> transfo = RandomApply(Gamma(), p=.5)
199
- >>> out = transfo(torch.rand(1, 64, 64, 3))
196
+ >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
200
197
 
201
198
  Args:
202
- ----
203
199
  transform: transformation to apply
204
200
  p: probability to apply
205
201
  """
@@ -211,7 +207,7 @@ class RandomApply(NestedObject):
211
207
  def extra_repr(self) -> str:
212
208
  return f"transform={self.transform}, p={self.p}"
213
209
 
214
- def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
210
+ def __call__(self, img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]:
215
211
  if random.random() < self.p:
216
212
  return self.transform(img) if target is None else self.transform(img, target) # type: ignore[call-arg]
217
213
  return img if target is None else (img, target)
@@ -224,9 +220,7 @@ class RandomRotate(NestedObject):
224
220
  :align: center
225
221
 
226
222
  Args:
227
- ----
228
- max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in
229
- [-max_angle, max_angle]
223
+ max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in [-max_angle, max_angle]
230
224
  expand: whether the image should be padded before the rotation
231
225
  """
232
226
 
@@ -237,7 +231,7 @@ class RandomRotate(NestedObject):
237
231
  def extra_repr(self) -> str:
238
232
  return f"max_angle={self.max_angle}, expand={self.expand}"
239
233
 
240
- def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
234
+ def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]:
241
235
  angle = random.uniform(-self.max_angle, self.max_angle)
242
236
  r_img, r_polys = F.rotate_sample(img, target, angle, self.expand)
243
237
  # Removes deleted boxes
@@ -249,19 +243,18 @@ class RandomCrop(NestedObject):
249
243
  """Randomly crop a tensor image and its boxes
250
244
 
251
245
  Args:
252
- ----
253
246
  scale: tuple of floats, relative (min_area, max_area) of the crop
254
247
  ratio: tuple of float, relative (min_ratio, max_ratio) where ratio = h/w
255
248
  """
256
249
 
257
- def __init__(self, scale: Tuple[float, float] = (0.08, 1.0), ratio: Tuple[float, float] = (0.75, 1.33)) -> None:
250
+ def __init__(self, scale: tuple[float, float] = (0.08, 1.0), ratio: tuple[float, float] = (0.75, 1.33)) -> None:
258
251
  self.scale = scale
259
252
  self.ratio = ratio
260
253
 
261
254
  def extra_repr(self) -> str:
262
255
  return f"scale={self.scale}, ratio={self.ratio}"
263
256
 
264
- def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
257
+ def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]:
265
258
  scale = random.uniform(self.scale[0], self.scale[1])
266
259
  ratio = random.uniform(self.ratio[0], self.ratio[1])
267
260