python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/utils/tensorflow.py CHANGED
@@ -1,12 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-import os
-from typing import Any, Callable, List, Optional, Tuple, Union
-from zipfile import ZipFile
+from collections.abc import Callable
+from typing import Any
 
 import tensorflow as tf
 import tf2onnx
@@ -19,6 +18,7 @@ logging.getLogger("tensorflow").setLevel(logging.DEBUG)
 
 __all__ = [
     "load_pretrained_params",
+    "_build_model",
     "conv_sequence",
     "IntermediateLayerGetter",
     "export_model_to_onnx",
@@ -36,51 +36,50 @@ def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
     return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x
 
 
+def _build_model(model: Model):
+    """Build a model by calling it once with dummy input
+
+    Args:
+        model: the model to be built
+    """
+    model(tf.zeros((1, *model.cfg["input_shape"])), training=False)
+
+
 def load_pretrained_params(
     model: Model,
-    url: Optional[str] = None,
-    hash_prefix: Optional[str] = None,
-    overwrite: bool = False,
-    internal_name: str = "weights",
+    url: str | None = None,
+    hash_prefix: str | None = None,
+    skip_mismatch: bool = False,
     **kwargs: Any,
 ) -> None:
     """Load a set of parameters onto a model
 
     >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
+    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5")
 
     Args:
-    ----
         model: the keras model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
-        overwrite: should the zip extraction be enforced if the archive has already been extracted
-        internal_name: name of the ckpt files
+        skip_mismatch: skip loading layers with mismatched shapes
        **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
    if url is None:
        logging.warning("Invalid model URL, using default initialization.")
    else:
        archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
-
-        # Unzip the archive
-        params_path = archive_path.parent.joinpath(archive_path.stem)
-        if not params_path.is_dir() or overwrite:
-            with ZipFile(archive_path, "r") as f:
-                f.extractall(path=params_path)
-
        # Load weights
-        model.load_weights(f"{params_path}{os.sep}{internal_name}")
+        model.load_weights(archive_path, skip_mismatch=skip_mismatch)
 
 
 def conv_sequence(
     out_channels: int,
-    activation: Optional[Union[str, Callable]] = None,
+    activation: str | Callable | None = None,
     bn: bool = False,
     padding: str = "same",
     kernel_initializer: str = "he_normal",
     **kwargs: Any,
-) -> List[layers.Layer]:
+) -> list[layers.Layer]:
     """Builds a convolutional-based layer sequence
 
     >>> from tensorflow.keras import Sequential
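In 0.11, TensorFlow checkpoints are fetched as a single Keras `.weights.h5` file and loaded in place instead of being unzipped into a checkpoint directory, and the new `_build_model` helper runs one dummy forward pass so Keras has allocated the weights before loading. A minimal sketch of the new loading flow; the URL and hash prefix are placeholders lifted from the docstring above, and `resnet18` stands in for any doctr TensorFlow model:

    from doctr.models.classification import resnet18
    from doctr.models.utils import _build_model, load_pretrained_params

    model = resnet18(pretrained=False)
    _build_model(model)  # one dummy forward pass so the weights exist before loading
    load_pretrained_params(
        model,
        url="https://yoursource.com/yourcheckpoint-yourhash.weights.h5",  # placeholder URL
        hash_prefix="yourhash",  # placeholder hash prefix
        skip_mismatch=True,  # tolerate layers whose shapes differ, e.g. a resized head
    )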
@@ -88,7 +87,6 @@ def conv_sequence(
     >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))
 
     Args:
-    ----
         out_channels: number of output channels
         activation: activation to be used (default: no activation)
         bn: should a batch normalization layer be added
@@ -97,7 +95,6 @@
         **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
-    -------
         list of layers
     """
     # No bias before Batch norm
@@ -125,12 +122,11 @@ class IntermediateLayerGetter(Model):
     >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers)
 
     Args:
-    ----
         model: the model to extract feature maps from
         layer_names: the list of layers to retrieve the feature map from
     """
 
-    def __init__(self, model: Model, layer_names: List[str]) -> None:
+    def __init__(self, model: Model, layer_names: list[str]) -> None:
         intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names]
         super().__init__(model.input, outputs=intermediate_fmaps)
@@ -139,8 +135,8 @@ class IntermediateLayerGetter(Model):
 
 
 def export_model_to_onnx(
-    model: Model, model_name: str, dummy_input: List[tf.TensorSpec], **kwargs: Any
-) -> Tuple[str, List[str]]:
+    model: Model, model_name: str, dummy_input: list[tf.TensorSpec], **kwargs: Any
+) -> tuple[str, list[str]]:
     """Export model to ONNX format.
 
     >>> import tensorflow as tf
@@ -151,16 +147,18 @@ def export_model_to_onnx(
     >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")])
 
     Args:
-    ----
         model: the keras model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to tf2onnx
 
     Returns:
-    -------
         the path to the exported model and a list with the output layer names
     """
+    # get the users eager mode
+    eager_mode = tf.executing_eagerly()
+    # set eager mode to true to avoid issues with tf2onnx
+    tf.config.run_functions_eagerly(True)
     large_model = kwargs.get("large_model", False)
     model_proto, _ = tf2onnx.convert.from_keras(
         model,
@@ -171,6 +169,9 @@
     # Get the output layer names
     output = [n.name for n in model_proto.graph.output]
 
+    # reset the eager mode to the users mode
+    tf.config.run_functions_eagerly(eager_mode)
+
     # models which are too large (weights > 2GB while converting to ONNX) needs to be handled
     # about an external tensor storage where the graph and weights are seperatly stored in a archive
     if large_model:
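`export_model_to_onnx` now records the caller's eager-execution setting, forces eager mode for the duration of the tf2onnx conversion, and restores the original setting afterwards; the call signature is otherwise unchanged. Usage stays as in the docstring (a sketch; model choice and input shape are illustrative):

    import tensorflow as tf

    from doctr.models.classification import resnet18
    from doctr.models.utils import export_model_to_onnx

    model = resnet18(pretrained=False)
    model_path, output_names = export_model_to_onnx(
        model,
        model_name="my_model",
        dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")],
    )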
doctr/models/zoo.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -83,7 +83,6 @@ def ocr_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -108,7 +107,6 @@
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
-    -------
         OCR predictor
     """
     return _predictor(
@@ -197,7 +195,6 @@ def kie_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -222,7 +219,6 @@
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
-    -------
         KIE predictor
     """
     return _kie_predictor(
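The zoo.py changes are cosmetic: a copyright bump and the removal of the ----/------- separators from the docstrings; `ocr_predictor` and `kie_predictor` keep their signatures. For context, the usage pattern those docstrings document (a sketch; the input path is a placeholder):

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    model = ocr_predictor(det_arch="db_resnet50", reco_arch="crnn_vgg16_bn", pretrained=True)
    doc = DocumentFile.from_images(["path/to/page.png"])  # placeholder path
    out = model(doc)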
doctr/transforms/functional/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
+if is_torch_available():
     from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *
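This flips the backend resolution order: with both frameworks installed, doctr now resolves `doctr.transforms.functional` (and, per the matching hunk at the end of this diff, `doctr.transforms.modules`) to the PyTorch implementations rather than the TensorFlow ones. A quick check of which backend wins, assuming both may be present:

    from doctr.file_utils import is_tf_available, is_torch_available

    if is_torch_available():
        print("PyTorch implementations selected")  # takes precedence as of this release
    elif is_tf_available():
        print("TensorFlow implementations selected")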
doctr/transforms/functional/base.py CHANGED
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Tuple, Union
 
 import cv2
 import numpy as np
@@ -15,17 +14,15 @@ __all__ = ["crop_boxes", "create_shadow_mask"]
 
 def crop_boxes(
     boxes: np.ndarray,
-    crop_box: Union[Tuple[int, int, int, int], Tuple[float, float, float, float]],
+    crop_box: tuple[int, int, int, int] | tuple[float, float, float, float],
 ) -> np.ndarray:
     """Crop localization boxes
 
     Args:
-    ----
         boxes: ndarray of shape (N, 4) in relative or abs coordinates
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in the same coord format that the boxes
 
     Returns:
-    -------
         the cropped boxes
     """
     is_box_rel = boxes.max() <= 1
@@ -49,17 +46,15 @@ def crop_boxes(
     return boxes[is_valid]
 
 
-def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float, float]:
+def expand_line(line: np.ndarray, target_shape: tuple[int, int]) -> tuple[float, float]:
     """Expands a 2-point line, so that the first is on the edge. In other terms, we extend the line in
     the same direction until we meet one of the edges.
 
     Args:
-    ----
         line: array of shape (2, 2) of the point supposed to be on one edge, and the shadow tip.
         target_shape: the desired mask shape
 
     Returns:
-    -------
         2D coordinates of the first point once we extended the line (on one of the edges)
     """
     if any(coord == 0 or coord == size for coord, size in zip(line[0], target_shape[::-1])):
@@ -112,7 +107,7 @@ def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float,
 
 
 def create_shadow_mask(
-    target_shape: Tuple[int, int],
+    target_shape: tuple[int, int],
     min_base_width=0.3,
     max_tip_width=0.5,
     max_tip_height=0.3,
@@ -120,14 +115,12 @@
     """Creates a random shadow mask
 
     Args:
-    ----
         target_shape: the target shape (H, W)
         min_base_width: the relative minimum shadow base width
         max_tip_width: the relative maximum shadow tip width
         max_tip_height: the relative maximum shadow tip height
 
     Returns:
-    -------
         a numpy ndarray of shape (H, W, 1) with values in the range [0, 1]
     """
     # Default base is top
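base.py only trades the `typing` imports for built-in generics and drops the docstring separators; behavior is unchanged. As a worked example of what `crop_boxes` does (values are illustrative):

    import numpy as np

    from doctr.transforms.functional.base import crop_boxes

    # Two boxes in relative coordinates; crop to the right half of the image
    boxes = np.array([[0.1, 0.1, 0.3, 0.3], [0.6, 0.2, 0.9, 0.5]], dtype=np.float32)
    cropped = crop_boxes(boxes, (0.5, 0.0, 1.0, 1.0))
    # The first box lies entirely outside the crop and is dropped; the second is
    # remapped to the new frame: [[0.2, 0.2, 0.8, 0.5]]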
doctr/transforms/functional/pytorch.py CHANGED
@@ -1,13 +1,13 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Tuple
 
 import numpy as np
 import torch
+from scipy.ndimage import gaussian_filter
 from torchvision.transforms import functional as F
 
 from doctr.utils.geometry import rotate_abs_geoms
@@ -21,12 +21,10 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     """Invert the colors of an image
 
     Args:
-    ----
         img : torch.Tensor, the image to invert
         min_val : minimum value of the random shift
 
     Returns:
-    -------
         the inverted image
     """
     out = F.rgb_to_grayscale(img, num_output_channels=3)
@@ -35,9 +33,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
     # Inverse the color
     if out.dtype == torch.uint8:
-        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
+        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)  # type: ignore[attr-defined]
     else:
-        out = out * rgb_shift.to(dtype=out.dtype)
+        out = out * rgb_shift.to(dtype=out.dtype)  # type: ignore[attr-defined]
     # Inverse the color
     out = 255 - out if out.dtype == torch.uint8 else 1 - out
     return out
@@ -48,18 +46,16 @@ def rotate_sample(
     geoms: np.ndarray,
     angle: float,
     expand: bool = False,
-) -> Tuple[torch.Tensor, np.ndarray]:
+) -> tuple[torch.Tensor, np.ndarray]:
     """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
 
     Args:
-    ----
         img: image to rotate
         geoms: array of geometries of shape (N, 4) or (N, 4, 2)
         angle: angle in degrees. +: counter-clockwise, -: clockwise
         expand: whether the image should be padded before the rotation
 
     Returns:
-    -------
         A tuple of rotated img (tensor), rotated geometries of shape (N, 4, 2)
     """
     rotated_img = F.rotate(img, angle=angle, fill=0, expand=expand)  # Interpolation NEAREST by default
@@ -81,7 +77,7 @@ def rotate_sample(
     rotated_geoms: np.ndarray = rotate_abs_geoms(
         _geoms,
         angle,
-        img.shape[1:],  # type: ignore[arg-type]
+        img.shape[1:],
         expand,
     ).astype(np.float32)
 
@@ -89,22 +85,20 @@ def rotate_sample(
     rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[2]
     rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[1]
 
-    return rotated_img, np.clip(rotated_geoms, 0, 1)
+    return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)
 
 
 def crop_detection(
-    img: torch.Tensor, boxes: np.ndarray, crop_box: Tuple[float, float, float, float]
-) -> Tuple[torch.Tensor, np.ndarray]:
+    img: torch.Tensor, boxes: np.ndarray, crop_box: tuple[float, float, float, float]
+) -> tuple[torch.Tensor, np.ndarray]:
     """Crop and image and associated bboxes
 
     Args:
-    ----
         img: image to crop
         boxes: array of boxes to clip, absolute (int) or relative (float)
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords.
 
     Returns:
-    -------
         A tuple of cropped image, cropped boxes, where the image is not resized.
     """
     if any(val < 0 or val > 1 for val in crop_box):
@@ -119,27 +113,25 @@ def crop_detection(
     return cropped_img, boxes
 
 
-def random_shadow(img: torch.Tensor, opacity_range: Tuple[float, float], **kwargs) -> torch.Tensor:
-    """Crop and image and associated bboxes
+def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwargs) -> torch.Tensor:
+    """Apply a random shadow effect to an image using NumPy for blurring.
 
     Args:
-    ----
-        img: image to modify
-        opacity_range: the minimum and maximum desired opacity of the shadow
-        **kwargs: additional arguments to pass to `create_shadow_mask`
+        img: Image to modify (C, H, W) as a PyTorch tensor.
+        opacity_range: The minimum and maximum desired opacity of the shadow.
+        **kwargs: Additional arguments to pass to `create_shadow_mask`.
 
     Returns:
-    -------
-        shaded image
+        Shadowed image as a PyTorch tensor (same shape as input).
     """
-    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)  # type: ignore[arg-type]
-
+    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
     opacity = np.random.uniform(*opacity_range)
-    shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])
 
-    # Add some blur to make it believable
-    k = 7 + 2 * int(4 * np.random.rand(1))
+    # Apply Gaussian blur to the shadow mask
     sigma = np.random.uniform(0.5, 5.0)
-    shadow_tensor = F.gaussian_blur(shadow_tensor, k, sigma=[sigma, sigma])
+    blurred_mask = gaussian_filter(shadow_mask, sigma=sigma)
+
+    shadow_tensor = 1 - torch.from_numpy(blurred_mask).float()
+    shadow_tensor = shadow_tensor.to(img.device).unsqueeze(0)  # Add channel dimension
 
     return opacity * shadow_tensor * img + (1 - opacity) * img
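The PyTorch `random_shadow` now blurs the shadow mask with `scipy.ndimage.gaussian_filter` on the NumPy side before converting to a tensor, replacing torchvision's `gaussian_blur` and the odd-kernel-size bookkeeping that went with it. A minimal sketch of the transform:

    import torch

    from doctr.transforms.functional.pytorch import random_shadow

    img = torch.rand(3, 64, 64)  # (C, H, W) float image in [0, 1]
    shadowed = random_shadow(img, opacity_range=(0.2, 0.8))
    assert shadowed.shape == img.shape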
doctr/transforms/functional/tensorflow.py CHANGED
@@ -1,12 +1,12 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
 import random
+from collections.abc import Iterable
 from copy import deepcopy
-from typing import Iterable, Optional, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
@@ -22,12 +22,10 @@ def invert_colors(img: tf.Tensor, min_val: float = 0.6) -> tf.Tensor:
     """Invert the colors of an image
 
     Args:
-    ----
         img : tf.Tensor, the image to invert
         min_val : minimum value of the random shift
 
     Returns:
-    -------
         the inverted image
     """
     out = tf.image.rgb_to_grayscale(img)  # Convert to gray
@@ -48,13 +46,11 @@ def rotated_img_tensor(img: tf.Tensor, angle: float, expand: bool = False) -> tf
     """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
 
     Args:
-    ----
         img: image to rotate
         angle: angle in degrees. +: counter-clockwise, -: clockwise
         expand: whether the image should be padded before the rotation
 
     Returns:
-    -------
         the rotated image (tensor)
     """
     # Compute the expanded padding
@@ -103,18 +99,16 @@ def rotate_sample(
     geoms: np.ndarray,
     angle: float,
     expand: bool = False,
-) -> Tuple[tf.Tensor, np.ndarray]:
+) -> tuple[tf.Tensor, np.ndarray]:
     """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
 
     Args:
-    ----
         img: image to rotate
         geoms: array of geometries of shape (N, 4) or (N, 4, 2)
         angle: angle in degrees. +: counter-clockwise, -: clockwise
         expand: whether the image should be padded before the rotation
 
     Returns:
-    -------
         A tuple of rotated img (tensor), rotated boxes (np array)
     """
     # Rotated the image
@@ -140,22 +134,20 @@ def rotate_sample(
     rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[1]
     rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[0]
 
-    return rotated_img, np.clip(rotated_geoms, 0, 1)
+    return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)
 
 
 def crop_detection(
-    img: tf.Tensor, boxes: np.ndarray, crop_box: Tuple[float, float, float, float]
-) -> Tuple[tf.Tensor, np.ndarray]:
+    img: tf.Tensor, boxes: np.ndarray, crop_box: tuple[float, float, float, float]
+) -> tuple[tf.Tensor, np.ndarray]:
     """Crop and image and associated bboxes
 
     Args:
-    ----
         img: image to crop
         boxes: array of boxes to clip, absolute (int) or relative (float)
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords.
 
     Returns:
-    -------
         A tuple of cropped image, cropped boxes, where the image is not resized.
     """
     if any(val < 0 or val > 1 for val in crop_box):
@@ -172,16 +164,15 @@ def crop_detection(
 
 def _gaussian_filter(
     img: tf.Tensor,
-    kernel_size: Union[int, Iterable[int]],
+    kernel_size: int | Iterable[int],
     sigma: float,
-    mode: Optional[str] = None,
-    pad_value: Optional[int] = 0,
+    mode: str | None = None,
+    pad_value: int = 0,
 ):
     """Apply Gaussian filter to image.
     Adapted from: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/filters.py
 
     Args:
-    ----
         img: image to filter of shape (N, H, W, C)
         kernel_size: kernel size of the filter
         sigma: standard deviation of the Gaussian filter
@@ -189,7 +180,6 @@ def _gaussian_filter(
         pad_value: value to pad the image with
 
     Returns:
-    -------
         A tensor of shape (N, H, W, C)
     """
     ksize = tf.convert_to_tensor(tf.broadcast_to(kernel_size, [2]), dtype=tf.int32)
@@ -235,17 +225,15 @@ def _gaussian_filter(
     return tf.nn.depthwise_conv2d(img, g, [1, 1, 1, 1], padding="VALID", data_format="NHWC")
 
 
-def random_shadow(img: tf.Tensor, opacity_range: Tuple[float, float], **kwargs) -> tf.Tensor:
+def random_shadow(img: tf.Tensor, opacity_range: tuple[float, float], **kwargs) -> tf.Tensor:
     """Apply a random shadow to a given image
 
     Args:
-    ----
         img: image to modify
         opacity_range: the minimum and maximum desired opacity of the shadow
         **kwargs: additional arguments to pass to `create_shadow_mask`
 
     Returns:
-    -------
         shadowed image
     """
     shadow_mask = create_shadow_mask(img.shape[:2], **kwargs)
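In both backends, `rotate_sample` now rounds the rotated geometries to 15 decimals before clipping to [0, 1], flushing the float noise the rotation matrix can introduce. The effect in isolation:

    import numpy as np

    noisy = np.array([1.0000000000000002, 0.30000000000000004, 0.6999999999999998])
    print(np.clip(np.around(noisy, decimals=15), 0, 1))  # -> [1.  0.3 0.7]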
doctr/transforms/modules/__init__.py CHANGED
@@ -2,7 +2,7 @@ from doctr.file_utils import is_tf_available, is_torch_available
 
 from .base import *
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]