python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/cord.py +10 -1
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +11 -1
  8. doctr/datasets/generator/base.py +6 -5
  9. doctr/datasets/ic03.py +11 -1
  10. doctr/datasets/ic13.py +10 -1
  11. doctr/datasets/iiit5k.py +26 -16
  12. doctr/datasets/imgur5k.py +11 -2
  13. doctr/datasets/loader.py +1 -6
  14. doctr/datasets/sroie.py +11 -1
  15. doctr/datasets/svhn.py +11 -1
  16. doctr/datasets/svt.py +11 -1
  17. doctr/datasets/synthtext.py +11 -1
  18. doctr/datasets/utils.py +9 -3
  19. doctr/datasets/vocabs.py +15 -4
  20. doctr/datasets/wildreceipt.py +12 -1
  21. doctr/file_utils.py +45 -12
  22. doctr/io/elements.py +52 -10
  23. doctr/io/html.py +2 -2
  24. doctr/io/image/pytorch.py +6 -8
  25. doctr/io/image/tensorflow.py +1 -1
  26. doctr/io/pdf.py +5 -2
  27. doctr/io/reader.py +6 -0
  28. doctr/models/__init__.py +0 -1
  29. doctr/models/_utils.py +57 -20
  30. doctr/models/builder.py +73 -15
  31. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  32. doctr/models/classification/mobilenet/pytorch.py +47 -9
  33. doctr/models/classification/mobilenet/tensorflow.py +51 -14
  34. doctr/models/classification/predictor/pytorch.py +28 -17
  35. doctr/models/classification/predictor/tensorflow.py +26 -16
  36. doctr/models/classification/resnet/tensorflow.py +21 -8
  37. doctr/models/classification/textnet/pytorch.py +3 -3
  38. doctr/models/classification/textnet/tensorflow.py +11 -5
  39. doctr/models/classification/vgg/tensorflow.py +9 -3
  40. doctr/models/classification/vit/tensorflow.py +10 -4
  41. doctr/models/classification/zoo.py +55 -19
  42. doctr/models/detection/_utils/__init__.py +1 -0
  43. doctr/models/detection/_utils/base.py +66 -0
  44. doctr/models/detection/differentiable_binarization/base.py +4 -3
  45. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  46. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  47. doctr/models/detection/fast/base.py +6 -5
  48. doctr/models/detection/fast/pytorch.py +4 -4
  49. doctr/models/detection/fast/tensorflow.py +15 -12
  50. doctr/models/detection/linknet/base.py +4 -3
  51. doctr/models/detection/linknet/tensorflow.py +23 -11
  52. doctr/models/detection/predictor/pytorch.py +15 -1
  53. doctr/models/detection/predictor/tensorflow.py +17 -3
  54. doctr/models/detection/zoo.py +7 -2
  55. doctr/models/factory/hub.py +8 -18
  56. doctr/models/kie_predictor/base.py +13 -3
  57. doctr/models/kie_predictor/pytorch.py +45 -20
  58. doctr/models/kie_predictor/tensorflow.py +44 -17
  59. doctr/models/modules/layers/pytorch.py +2 -3
  60. doctr/models/modules/layers/tensorflow.py +6 -8
  61. doctr/models/modules/transformer/pytorch.py +2 -2
  62. doctr/models/modules/transformer/tensorflow.py +0 -2
  63. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  64. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  65. doctr/models/predictor/base.py +97 -58
  66. doctr/models/predictor/pytorch.py +35 -20
  67. doctr/models/predictor/tensorflow.py +35 -18
  68. doctr/models/preprocessor/pytorch.py +4 -4
  69. doctr/models/preprocessor/tensorflow.py +3 -2
  70. doctr/models/recognition/crnn/tensorflow.py +8 -6
  71. doctr/models/recognition/master/pytorch.py +2 -2
  72. doctr/models/recognition/master/tensorflow.py +9 -4
  73. doctr/models/recognition/parseq/pytorch.py +4 -3
  74. doctr/models/recognition/parseq/tensorflow.py +14 -11
  75. doctr/models/recognition/sar/pytorch.py +7 -6
  76. doctr/models/recognition/sar/tensorflow.py +10 -12
  77. doctr/models/recognition/vitstr/pytorch.py +1 -1
  78. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  79. doctr/models/recognition/zoo.py +1 -1
  80. doctr/models/utils/pytorch.py +1 -1
  81. doctr/models/utils/tensorflow.py +15 -15
  82. doctr/models/zoo.py +2 -2
  83. doctr/py.typed +0 -0
  84. doctr/transforms/functional/base.py +1 -1
  85. doctr/transforms/functional/pytorch.py +5 -5
  86. doctr/transforms/modules/base.py +37 -15
  87. doctr/transforms/modules/pytorch.py +73 -14
  88. doctr/transforms/modules/tensorflow.py +78 -19
  89. doctr/utils/fonts.py +7 -5
  90. doctr/utils/geometry.py +141 -31
  91. doctr/utils/metrics.py +34 -175
  92. doctr/utils/reconstitution.py +212 -0
  93. doctr/utils/visualization.py +5 -118
  94. doctr/version.py +1 -1
  95. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
  96. python_doctr-0.10.0.dist-info/RECORD +173 -0
  97. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  98. doctr/models/artefacts/__init__.py +0 -2
  99. doctr/models/artefacts/barcode.py +0 -74
  100. doctr/models/artefacts/face.py +0 -63
  101. doctr/models/obj_detection/__init__.py +0 -1
  102. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  103. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  104. python_doctr-0.8.1.dist-info/RECORD +0 -173
  105. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  106. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  107. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/vitstr/tensorflow.py CHANGED
@@ -12,7 +12,7 @@ from tensorflow.keras import Model, layers
 from doctr.datasets import VOCABS
 
 from ...classification import vit_b, vit_s
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -23,14 +23,14 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_small-358fab2e.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
     },
     "vitstr_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_base-2889159a.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0",
     },
 }
 
@@ -216,9 +216,14 @@ def _vitstr(
 
     # Build the model
     model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
    if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
 
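Net effect of this file's changes: TensorFlow ViTSTR checkpoints now ship as single .weights.h5 files, the model is built eagerly, and requesting a non-default vocab loads the pretrained weights with skip_mismatch so only the shape-mismatched head is re-initialized. A minimal fine-tuning sketch (the vocab choice is illustrative):

    from doctr.datasets import VOCABS
    from doctr.models import vitstr_small

    # Same vocab as the checkpoint: weights load strictly
    model = vitstr_small(pretrained=True)
    # Different vocab: mismatched classification layers are skipped, the rest is loaded
    model_ft = vitstr_small(pretrained=True, vocab=VOCABS["english"])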
doctr/models/recognition/zoo.py CHANGED
@@ -45,7 +45,7 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
 
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size", 32)
+    kwargs["batch_size"] = kwargs.get("batch_size", 128)
     input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:]
     predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)
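The only change here is the default recognition batch size, raised from 32 to 128. Callers that tuned memory or throughput around the old default can restore it explicitly; a hedged sketch:

    from doctr.models import recognition_predictor

    # batch_size is forwarded to the PreProcessor; 32 was the 0.8.x default
    predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=32)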
doctr/models/utils/pytorch.py CHANGED
@@ -157,7 +157,7 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.Tensor
     """
     torch.onnx.export(
         model,
-        dummy_input,
+        dummy_input,  # type: ignore[arg-type]
         f"{model_name}.onnx",
         input_names=["input"],
         output_names=["logits"],
doctr/models/utils/tensorflow.py CHANGED
@@ -4,9 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-import os
 from typing import Any, Callable, List, Optional, Tuple, Union
-from zipfile import ZipFile
 
 import tensorflow as tf
 import tf2onnx
@@ -19,6 +17,7 @@ logging.getLogger("tensorflow").setLevel(logging.DEBUG)
 
 __all__ = [
     "load_pretrained_params",
+    "_build_model",
     "conv_sequence",
     "IntermediateLayerGetter",
     "export_model_to_onnx",
@@ -36,41 +35,42 @@ def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
     return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x
 
 
+def _build_model(model: Model):
+    """Build a model by calling it once with dummy input
+
+    Args:
+    ----
+        model: the model to be built
+    """
+    model(tf.zeros((1, *model.cfg["input_shape"])), training=False)
+
+
 def load_pretrained_params(
     model: Model,
     url: Optional[str] = None,
     hash_prefix: Optional[str] = None,
-    overwrite: bool = False,
-    internal_name: str = "weights",
+    skip_mismatch: bool = False,
     **kwargs: Any,
 ) -> None:
     """Load a set of parameters onto a model
 
     >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
+    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5")
 
     Args:
     ----
         model: the keras model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
-        overwrite: should the zip extraction be enforced if the archive has already been extracted
-        internal_name: name of the ckpt files
+        skip_mismatch: skip loading layers with mismatched shapes
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
     if url is None:
         logging.warning("Invalid model URL, using default initialization.")
     else:
         archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
-
-        # Unzip the archive
-        params_path = archive_path.parent.joinpath(archive_path.stem)
-        if not params_path.is_dir() or overwrite:
-            with ZipFile(archive_path, "r") as f:
-                f.extractall(path=params_path)
-
         # Load weights
-        model.load_weights(f"{params_path}{os.sep}{internal_name}")
+        model.load_weights(archive_path, skip_mismatch=skip_mismatch)
 
 
 def conv_sequence(
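Two behavioral shifts are visible above: checkpoints are single .weights.h5 files loaded directly (no zip extraction, hence the removal of overwrite and internal_name), and a Keras model must be built before weights can be loaded, which _build_model does with one dummy forward pass. A hedged sketch of manual loading (URL taken from the vitstr config diff above):

    from doctr.models import vitstr_small
    from doctr.models.utils.tensorflow import load_pretrained_params

    model = vitstr_small(pretrained=False)  # _vitstr already calls _build_model internally
    load_pretrained_params(
        model,
        "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
        skip_mismatch=False,  # set True to tolerate a head resized for another vocab
    )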
doctr/models/zoo.py CHANGED
@@ -61,7 +61,7 @@ def _predictor(
 
 
 def ocr_predictor(
-    det_arch: Any = "db_resnet50",
+    det_arch: Any = "fast_base",
     reco_arch: Any = "crnn_vgg16_bn",
     pretrained: bool = False,
     pretrained_backbone: bool = True,
@@ -175,7 +175,7 @@ def _kie_predictor(
 
 
 def kie_predictor(
-    det_arch: Any = "db_resnet50",
+    det_arch: Any = "fast_base",
     reco_arch: Any = "crnn_vgg16_bn",
     pretrained: bool = False,
     pretrained_backbone: bool = True,
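Both high-level entry points now default to the lighter fast_base detector instead of db_resnet50. Pin det_arch to reproduce the previous pipeline:

    from doctr.models import ocr_predictor

    # Explicit det_arch keeps the 0.8.x detection backbone
    predictor = ocr_predictor(det_arch="db_resnet50", reco_arch="crnn_vgg16_bn", pretrained=True)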
doctr/py.typed ADDED
File without changes
doctr/transforms/functional/base.py CHANGED
@@ -200,4 +200,4 @@ def create_shadow_mask(
     mask: np.ndarray = np.zeros((*target_shape, 1), dtype=np.uint8)
     mask = cv2.fillPoly(mask, [final_contour], (255,), lineType=cv2.LINE_AA)[..., 0]
 
-    return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32)  # type: ignore[operator]
+    return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32)
doctr/transforms/functional/pytorch.py CHANGED
@@ -35,9 +35,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
     # Inverse the color
     if out.dtype == torch.uint8:
-        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)  # type: ignore[attr-defined]
+        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
     else:
-        out = out * rgb_shift.to(dtype=out.dtype)  # type: ignore[attr-defined]
+        out = out * rgb_shift.to(dtype=out.dtype)
     # Inverse the color
     out = 255 - out if out.dtype == torch.uint8 else 1 - out
     return out
@@ -81,7 +81,7 @@ def rotate_sample(
     rotated_geoms: np.ndarray = rotate_abs_geoms(
         _geoms,
         angle,
-        img.shape[1:],
+        img.shape[1:],  # type: ignore[arg-type]
         expand,
     ).astype(np.float32)
 
@@ -89,7 +89,7 @@ def rotate_sample(
     rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[2]
     rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[1]
 
-    return rotated_img, np.clip(rotated_geoms, 0, 1)
+    return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)
 
 
 def crop_detection(
@@ -132,7 +132,7 @@ def random_shadow(img: torch.Tensor, opacity_range: Tuple[float, float], **kwargs
     -------
         shaded image
     """
-    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
+    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)  # type: ignore[arg-type]
 
     opacity = np.random.uniform(*opacity_range)
     shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])
doctr/transforms/modules/base.py CHANGED
@@ -5,7 +5,7 @@
 
 import math
 import random
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -168,11 +168,11 @@ class OneOf(NestedObject):
     def __init__(self, transforms: List[Callable[[Any], Any]]) -> None:
         self.transforms = transforms
 
-    def __call__(self, img: Any) -> Any:
+    def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
         # Pick transformation
         transfo = self.transforms[int(random.random() * len(self.transforms))]
         # Apply
-        return transfo(img)
+        return transfo(img) if target is None else transfo(img, target)  # type: ignore[call-arg]
 
 
 class RandomApply(NestedObject):
@@ -261,17 +261,39 @@ class RandomCrop(NestedObject):
     def extra_repr(self) -> str:
         return f"scale={self.scale}, ratio={self.ratio}"
 
-    def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]:
+    def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
         scale = random.uniform(self.scale[0], self.scale[1])
         ratio = random.uniform(self.ratio[0], self.ratio[1])
-        # Those might overflow
-        crop_h = math.sqrt(scale * ratio)
-        crop_w = math.sqrt(scale / ratio)
-        xmin, ymin = random.uniform(0, 1 - crop_w), random.uniform(0, 1 - crop_h)
-        xmax, ymax = xmin + crop_w, ymin + crop_h
-        # Clip them
-        xmin, ymin = max(xmin, 0), max(ymin, 0)
-        xmax, ymax = min(xmax, 1), min(ymax, 1)
-
-        croped_img, crop_boxes = F.crop_detection(img, target["boxes"], (xmin, ymin, xmax, ymax))
-        return croped_img, dict(boxes=crop_boxes)
+
+        height, width = img.shape[:2]
+
+        # Calculate crop size
+        crop_area = scale * width * height
+        aspect_ratio = ratio * (width / height)
+        crop_width = int(round(math.sqrt(crop_area * aspect_ratio)))
+        crop_height = int(round(math.sqrt(crop_area / aspect_ratio)))
+
+        # Ensure crop size does not exceed image dimensions
+        crop_width = min(crop_width, width)
+        crop_height = min(crop_height, height)
+
+        # Randomly select crop position
+        x = random.randint(0, width - crop_width)
+        y = random.randint(0, height - crop_height)
+
+        # relative crop box
+        crop_box = (x / width, y / height, (x + crop_width) / width, (y + crop_height) / height)
+        if target.shape[1:] == (4, 2):
+            min_xy = np.min(target, axis=1)
+            max_xy = np.max(target, axis=1)
+            _target = np.concatenate((min_xy, max_xy), axis=1)
+        else:
+            _target = target
+
+        # Crop image and targets
+        croped_img, crop_boxes = F.crop_detection(img, _target, crop_box)
+        # hard fallback if no box is kept
+        if crop_boxes.shape[0] == 0:
+            return img, target
+        # clip boxes
+        return croped_img, np.clip(crop_boxes, 0, 1)
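The rewritten RandomCrop works in absolute pixels, reduces (N, 4, 2) polygon targets to axis-aligned boxes via per-box min/max corners, and returns the sample unchanged if the crop would drop every box. A hedged sketch of the new call contract, assuming the TensorFlow backend and an (H, W, C) image (scale/ratio values are illustrative):

    import numpy as np
    import tensorflow as tf
    from doctr.transforms import RandomCrop

    transfo = RandomCrop(scale=(0.08, 1.0), ratio=(0.75, 1.33))
    img = tf.random.uniform((64, 64, 3), 0, 1)  # img.shape[:2] == (height, width)
    boxes = np.array([[0.1, 0.1, 0.4, 0.5]], dtype=np.float32)  # relative (N, 4) boxes
    cropped_img, cropped_boxes = transfo(img, boxes)  # boxes come back clipped to [0, 1]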
doctr/transforms/modules/pytorch.py CHANGED
@@ -4,7 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -15,7 +15,7 @@ from torchvision.transforms import transforms as T
 
 from ..functional.pytorch import random_shadow
 
-__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow"]
+__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"]
 
 
 class Resize(T.Resize):
@@ -74,16 +74,18 @@ class Resize(T.Resize):
             if self.symmetric_pad:
                 half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2))
                 _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1])
+            # Pad image
             img = pad(img, _pad)
 
         # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
         if target is not None:
+            if self.symmetric_pad:
+                offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
+
             if self.preserve_aspect_ratio:
                 # Get absolute coords
                 if target.shape[1:] == (4,):
                     if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
                         target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
                         target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
                     else:
@@ -91,16 +93,15 @@ class Resize(T.Resize):
                         target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
                 elif target.shape[1:] == (4, 2):
                     if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
                         target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
                         target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
                     else:
                         target[..., 0] *= raw_shape[-1] / img.shape[-1]
                         target[..., 1] *= raw_shape[-2] / img.shape[-2]
                 else:
-                    raise AssertionError
-            return img, target
+                    raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")
+
+            return img, np.clip(target, 0, 1)
 
         return img
 
@@ -135,9 +136,9 @@ class GaussianNoise(torch.nn.Module):
         # Reshape the distribution
         noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std
         if x.dtype == torch.uint8:
-            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)  # type: ignore[attr-defined]
+            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
         else:
-            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)  # type: ignore[attr-defined]
+            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)
 
     def extra_repr(self) -> str:
         return f"mean={self.mean}, std={self.std}"
@@ -159,13 +160,16 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
     """Randomly flip the input image horizontally"""
 
     def forward(
-        self, img: Union[torch.Tensor, Image], target: Dict[str, Any]
-    ) -> Tuple[Union[torch.Tensor, Image], Dict[str, Any]]:
+        self, img: Union[torch.Tensor, Image], target: np.ndarray
+    ) -> Tuple[Union[torch.Tensor, Image], np.ndarray]:
         if torch.rand(1) < self.p:
             _img = F.hflip(img)
             _target = target.copy()
             # Changing the relative bbox coordinates
-            _target["boxes"][:, ::2] = 1 - target["boxes"][:, [2, 0]]
+            if target.shape[1:] == (4,):
+                _target[:, ::2] = 1 - target[:, [2, 0]]
+            else:
+                _target[..., 0] = 1 - target[..., 0]
             return _img, _target
         return img, target
 
@@ -199,7 +203,7 @@ class RandomShadow(torch.nn.Module):
                         self.opacity_range,
                     )
                 )
-                .round()  # type: ignore[attr-defined]
+                .round()
                .clip(0, 255)
                 .to(dtype=torch.uint8)
             )
@@ -210,3 +214,58 @@ class RandomShadow(torch.nn.Module):
 
     def extra_repr(self) -> str:
         return f"opacity_range={self.opacity_range}"
+
+
+class RandomResize(torch.nn.Module):
+    """Randomly resize the input image and align corresponding targets
+
+    >>> import torch
+    >>> from doctr.transforms import RandomResize
+    >>> transfo = RandomResize((0.3, 0.9), preserve_aspect_ratio=True, symmetric_pad=True, p=0.5)
+    >>> out = transfo(torch.rand((3, 64, 64)))
+
+    Args:
+    ----
+        scale_range: range of the resizing factor for width and height (independently)
+        preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
+            given a float value, the aspect ratio will be preserved with this probability
+        symmetric_pad: whether to symmetrically pad the image,
+            given a float value, the symmetric padding will be applied with this probability
+        p: probability to apply the transformation
+    """
+
+    def __init__(
+        self,
+        scale_range: Tuple[float, float] = (0.3, 0.9),
+        preserve_aspect_ratio: Union[bool, float] = False,
+        symmetric_pad: Union[bool, float] = False,
+        p: float = 0.5,
+    ) -> None:
+        super().__init__()
+        self.scale_range = scale_range
+        self.preserve_aspect_ratio = preserve_aspect_ratio
+        self.symmetric_pad = symmetric_pad
+        self.p = p
+        self._resize = Resize
+
+    def forward(self, img: torch.Tensor, target: np.ndarray) -> Tuple[torch.Tensor, np.ndarray]:
+        if torch.rand(1) < self.p:
+            scale_h = np.random.uniform(*self.scale_range)
+            scale_w = np.random.uniform(*self.scale_range)
+            new_size = (int(img.shape[-2] * scale_h), int(img.shape[-1] * scale_w))
+
+            _img, _target = self._resize(
+                new_size,
+                preserve_aspect_ratio=self.preserve_aspect_ratio
+                if isinstance(self.preserve_aspect_ratio, bool)
+                else bool(torch.rand(1) <= self.symmetric_pad),
+                symmetric_pad=self.symmetric_pad
+                if isinstance(self.symmetric_pad, bool)
+                else bool(torch.rand(1) <= self.symmetric_pad),
+            )(img, target)
+
+            return _img, _target
+        return img, target
+
+    def extra_repr(self) -> str:
+        return f"scale_range={self.scale_range}, preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}, p={self.p}"  # noqa: E501
doctr/transforms/modules/tensorflow.py CHANGED
@@ -4,7 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import random
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
@@ -30,6 +30,7 @@ __all__ = [
     "GaussianNoise",
     "RandomHorizontalFlip",
     "RandomShadow",
+    "RandomResize",
 ]
 
 
@@ -106,29 +107,34 @@ class Resize(NestedObject):
         target: Optional[np.ndarray] = None,
     ) -> Union[tf.Tensor, Tuple[tf.Tensor, np.ndarray]]:
         input_dtype = img.dtype
+        self.output_size = (
+            (self.output_size, self.output_size) if isinstance(self.output_size, int) else self.output_size
+        )
 
         img = tf.image.resize(img, self.wanted_size, self.method, self.preserve_aspect_ratio, self.antialias)
         # It will produce an un-padded resized image, with a side shorter than wanted if we preserve aspect ratio
         raw_shape = img.shape[:2]
+        if self.symmetric_pad:
+            half_pad = (int((self.output_size[0] - img.shape[0]) / 2), 0)
         if self.preserve_aspect_ratio:
             if isinstance(self.output_size, (tuple, list)):
                 # In that case we need to pad because we want to enforce both width and height
                 if not self.symmetric_pad:
-                    offset = (0, 0)
+                    half_pad = (0, 0)
                 elif self.output_size[0] == img.shape[0]:
-                    offset = (0, int((self.output_size[1] - img.shape[1]) / 2))
-                else:
-                    offset = (int((self.output_size[0] - img.shape[0]) / 2), 0)
-                img = tf.image.pad_to_bounding_box(img, *offset, *self.output_size)
+                    half_pad = (0, int((self.output_size[1] - img.shape[1]) / 2))
+                # Pad image
+                img = tf.image.pad_to_bounding_box(img, *half_pad, *self.output_size)
 
         # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
         if target is not None:
+            if self.symmetric_pad:
+                offset = half_pad[0] / img.shape[0], half_pad[1] / img.shape[1]
+
             if self.preserve_aspect_ratio:
                 # Get absolute coords
                 if target.shape[1:] == (4,):
                     if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = offset[0] / img.shape[0], offset[1] / img.shape[1]
                         target[:, [0, 2]] = offset[1] + target[:, [0, 2]] * raw_shape[1] / img.shape[1]
                         target[:, [1, 3]] = offset[0] + target[:, [1, 3]] * raw_shape[0] / img.shape[0]
                     else:
@@ -136,16 +142,15 @@ class Resize(NestedObject):
                         target[:, [1, 3]] *= raw_shape[0] / img.shape[0]
                 elif target.shape[1:] == (4, 2):
                     if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = offset[0] / img.shape[0], offset[1] / img.shape[1]
                         target[..., 0] = offset[1] + target[..., 0] * raw_shape[1] / img.shape[1]
                         target[..., 1] = offset[0] + target[..., 1] * raw_shape[0] / img.shape[0]
                     else:
                         target[..., 0] *= raw_shape[1] / img.shape[1]
                         target[..., 1] *= raw_shape[0] / img.shape[0]
                 else:
-                    raise AssertionError
-            return tf.cast(img, dtype=input_dtype), target
+                    raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")
+
+            return tf.cast(img, dtype=input_dtype), np.clip(target, 0, 1)
 
         return tf.cast(img, dtype=input_dtype)
 
@@ -394,7 +399,6 @@ class GaussianBlur(NestedObject):
     def extra_repr(self) -> str:
         return f"kernel_shape={self.kernel_shape}, std={self.std}"
 
-    @tf.function
     def __call__(self, img: tf.Tensor) -> tf.Tensor:
         return tf.squeeze(
             _gaussian_filter(
@@ -457,10 +461,7 @@ class RandomHorizontalFlip(NestedObject):
     >>> from doctr.transforms import RandomHorizontalFlip
     >>> transfo = RandomHorizontalFlip(p=0.5)
     >>> image = tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)
-    >>> target = {
-    >>>     "boxes": np.array([[0.1, 0.1, 0.4, 0.5]], dtype=np.float32),
-    >>>     "labels": np.ones(1, dtype=np.int64)
-    >>> }
+    >>> target = np.array([[0.1, 0.1, 0.4, 0.5]], dtype=np.float32)
     >>> out = transfo(image, target)
 
     Args:
@@ -472,12 +473,15 @@ class RandomHorizontalFlip(NestedObject):
         super().__init__()
         self.p = p
 
-    def __call__(self, img: Union[tf.Tensor, np.ndarray], target: Dict[str, Any]) -> Tuple[tf.Tensor, Dict[str, Any]]:
+    def __call__(self, img: Union[tf.Tensor, np.ndarray], target: np.ndarray) -> Tuple[tf.Tensor, np.ndarray]:
         if np.random.rand(1) <= self.p:
             _img = tf.image.flip_left_right(img)
             _target = target.copy()
             # Changing the relative bbox coordinates
-            _target["boxes"][:, ::2] = 1 - target["boxes"][:, [2, 0]]
+            if target.shape[1:] == (4,):
+                _target[:, ::2] = 1 - target[:, [2, 0]]
+            else:
+                _target[..., 0] = 1 - target[..., 0]
             return _img, _target
         return img, target
 
@@ -515,3 +519,58 @@ class RandomShadow(NestedObject):
 
     def extra_repr(self) -> str:
         return f"opacity_range={self.opacity_range}"
+
+
+class RandomResize(NestedObject):
+    """Randomly resize the input image and align corresponding targets
+
+    >>> import tensorflow as tf
+    >>> from doctr.transforms import RandomResize
+    >>> transfo = RandomResize((0.3, 0.9), preserve_aspect_ratio=True, symmetric_pad=True, p=0.5)
+    >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+
+    Args:
+    ----
+        scale_range: range of the resizing factor for width and height (independently)
+        preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
+            given a float value, the aspect ratio will be preserved with this probability
+        symmetric_pad: whether to symmetrically pad the image,
+            given a float value, the symmetric padding will be applied with this probability
+        p: probability to apply the transformation
+    """
+
+    def __init__(
+        self,
+        scale_range: Tuple[float, float] = (0.3, 0.9),
+        preserve_aspect_ratio: Union[bool, float] = False,
+        symmetric_pad: Union[bool, float] = False,
+        p: float = 0.5,
+    ):
+        super().__init__()
+        self.scale_range = scale_range
+        self.preserve_aspect_ratio = preserve_aspect_ratio
+        self.symmetric_pad = symmetric_pad
+        self.p = p
+        self._resize = Resize
+
+    def __call__(self, img: tf.Tensor, target: np.ndarray) -> Tuple[tf.Tensor, np.ndarray]:
+        if np.random.rand(1) <= self.p:
+            scale_h = random.uniform(*self.scale_range)
+            scale_w = random.uniform(*self.scale_range)
+            new_size = (int(img.shape[-3] * scale_h), int(img.shape[-2] * scale_w))
+
+            _img, _target = self._resize(
+                new_size,
+                preserve_aspect_ratio=self.preserve_aspect_ratio
+                if isinstance(self.preserve_aspect_ratio, bool)
+                else bool(np.random.rand(1) <= self.symmetric_pad),
+                symmetric_pad=self.symmetric_pad
+                if isinstance(self.symmetric_pad, bool)
+                else bool(np.random.rand(1) <= self.symmetric_pad),
+            )(img, target)
+
+            return _img, _target
+        return img, target
+
+    def extra_repr(self) -> str:
+        return f"scale_range={self.scale_range}, preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}, p={self.p}"  # noqa: E501
doctr/utils/fonts.py CHANGED
@@ -5,14 +5,16 @@
 
 import logging
 import platform
-from typing import Optional
+from typing import Optional, Union
 
 from PIL import ImageFont
 
 __all__ = ["get_font"]
 
 
-def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont:
+def get_font(
+    font_family: Optional[str] = None, font_size: int = 13
+) -> Union[ImageFont.FreeTypeFont, ImageFont.ImageFont]:
     """Resolves a compatible ImageFont for the system
 
     Args:
@@ -28,14 +30,14 @@ def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont:
     if font_family is None:
         try:
             font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size)
-        except OSError:
-            font = ImageFont.load_default()
+        except OSError:  # pragma: no cover
+            font = ImageFont.load_default()  # type: ignore[assignment]
             logging.warning(
                 "unable to load recommended font family. Loading default PIL font,"
                 "font size issues may be expected."
                 "To prevent this, it is recommended to specify the value of 'font_family'."
             )
-    else:
+    else:  # pragma: no cover
         font = ImageFont.truetype(font_family, font_size)
 
     return font
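The widened return annotation reflects what PIL actually hands back: a FreeTypeFont when a TTF is resolvable, or the bitmap default font as a fallback. A hedged sketch (the font name is illustrative):

    from doctr.utils.fonts import get_font

    font = get_font(font_size=16)  # tries FreeMono.ttf (Linux) or Arial.ttf, falls back to PIL default
    custom = get_font("DejaVuSans.ttf", 13)  # any TTF that PIL can resolve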