python-doctr 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/datasets/pytorch.py +2 -2
  6. doctr/datasets/generator/base.py +6 -5
  7. doctr/datasets/imgur5k.py +1 -1
  8. doctr/datasets/loader.py +1 -6
  9. doctr/datasets/utils.py +2 -1
  10. doctr/datasets/vocabs.py +9 -2
  11. doctr/file_utils.py +26 -12
  12. doctr/io/elements.py +40 -6
  13. doctr/io/html.py +2 -2
  14. doctr/io/image/pytorch.py +6 -8
  15. doctr/io/image/tensorflow.py +1 -1
  16. doctr/io/pdf.py +5 -2
  17. doctr/io/reader.py +6 -0
  18. doctr/models/__init__.py +0 -1
  19. doctr/models/_utils.py +57 -20
  20. doctr/models/builder.py +71 -13
  21. doctr/models/classification/mobilenet/pytorch.py +45 -9
  22. doctr/models/classification/mobilenet/tensorflow.py +38 -7
  23. doctr/models/classification/predictor/pytorch.py +18 -11
  24. doctr/models/classification/predictor/tensorflow.py +16 -10
  25. doctr/models/classification/textnet/pytorch.py +3 -3
  26. doctr/models/classification/textnet/tensorflow.py +3 -3
  27. doctr/models/classification/zoo.py +39 -15
  28. doctr/models/detection/_utils/__init__.py +1 -0
  29. doctr/models/detection/_utils/base.py +66 -0
  30. doctr/models/detection/differentiable_binarization/base.py +4 -3
  31. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  32. doctr/models/detection/fast/base.py +6 -5
  33. doctr/models/detection/fast/pytorch.py +4 -4
  34. doctr/models/detection/fast/tensorflow.py +4 -4
  35. doctr/models/detection/linknet/base.py +4 -3
  36. doctr/models/detection/predictor/pytorch.py +15 -1
  37. doctr/models/detection/predictor/tensorflow.py +15 -1
  38. doctr/models/detection/zoo.py +7 -2
  39. doctr/models/factory/hub.py +3 -12
  40. doctr/models/kie_predictor/base.py +9 -3
  41. doctr/models/kie_predictor/pytorch.py +41 -20
  42. doctr/models/kie_predictor/tensorflow.py +36 -16
  43. doctr/models/modules/layers/pytorch.py +2 -3
  44. doctr/models/modules/layers/tensorflow.py +6 -8
  45. doctr/models/modules/transformer/pytorch.py +2 -2
  46. doctr/models/predictor/base.py +77 -50
  47. doctr/models/predictor/pytorch.py +31 -20
  48. doctr/models/predictor/tensorflow.py +27 -17
  49. doctr/models/preprocessor/pytorch.py +4 -4
  50. doctr/models/preprocessor/tensorflow.py +3 -2
  51. doctr/models/recognition/master/pytorch.py +2 -2
  52. doctr/models/recognition/parseq/pytorch.py +4 -3
  53. doctr/models/recognition/parseq/tensorflow.py +4 -3
  54. doctr/models/recognition/sar/pytorch.py +7 -6
  55. doctr/models/recognition/sar/tensorflow.py +3 -9
  56. doctr/models/recognition/vitstr/pytorch.py +1 -1
  57. doctr/models/recognition/zoo.py +1 -1
  58. doctr/models/zoo.py +2 -2
  59. doctr/py.typed +0 -0
  60. doctr/transforms/functional/base.py +1 -1
  61. doctr/transforms/functional/pytorch.py +4 -4
  62. doctr/transforms/modules/base.py +37 -15
  63. doctr/transforms/modules/pytorch.py +66 -8
  64. doctr/transforms/modules/tensorflow.py +63 -7
  65. doctr/utils/fonts.py +7 -5
  66. doctr/utils/geometry.py +35 -12
  67. doctr/utils/metrics.py +33 -174
  68. doctr/utils/reconstitution.py +126 -0
  69. doctr/utils/visualization.py +5 -118
  70. doctr/version.py +1 -1
  71. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/METADATA +84 -80
  72. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/RECORD +76 -76
  73. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/WHEEL +1 -1
  74. doctr/models/artefacts/__init__.py +0 -2
  75. doctr/models/artefacts/barcode.py +0 -74
  76. doctr/models/artefacts/face.py +0 -63
  77. doctr/models/obj_detection/__init__.py +0 -1
  78. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  79. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  80. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/LICENSE +0 -0
  81. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/top_level.txt +0 -0
  82. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/zip-safe +0 -0
@@ -9,10 +9,10 @@ import numpy as np
  import tensorflow as tf

  from doctr.io.elements import Document
- from doctr.models._utils import estimate_orientation, get_language
+ from doctr.models._utils import get_language
  from doctr.models.detection.predictor import DetectionPredictor
  from doctr.models.recognition.predictor import RecognitionPredictor
- from doctr.utils.geometry import rotate_image
+ from doctr.utils.geometry import detach_scores
  from doctr.utils.repr import NestedObject

  from .base import _OCRPredictor
@@ -56,7 +56,13 @@ class OCRPredictor(NestedObject, _OCRPredictor):
  self.det_predictor = det_predictor
  self.reco_predictor = reco_predictor
  _OCRPredictor.__init__(
- self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
+ self,
+ assume_straight_pages,
+ straighten_pages,
+ preserve_aspect_ratio,
+ symmetric_pad,
+ detect_orientation,
+ **kwargs,
  )
  self.detect_orientation = detect_orientation
  self.detect_language = detect_language
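For users, orientation handling is still driven by the existing `ocr_predictor` keyword arguments; the constructor change above simply forwards `detect_orientation` to the shared base class. A minimal usage sketch (the input file name is hypothetical, and the exact shape of the orientation metadata is inferred from this diff):

    # Sketch: enabling page orientation detection and straightening on the high-level API.
    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    doc = DocumentFile.from_pdf("sample.pdf")  # hypothetical input document
    predictor = ocr_predictor(pretrained=True, detect_orientation=True, straighten_pages=True)
    result = predictor(doc)
    print(result.pages[0].orientation)  # e.g. {"value": -90, "confidence": None}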
@@ -81,19 +87,16 @@ class OCRPredictor(NestedObject, _OCRPredictor):
  for out_map in out_maps
  ]
  if self.detect_orientation:
- origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
  orientations = [
- {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
  ]
  else:
  orientations = None
+ general_pages_orientations = None
+ origin_pages_orientations = None
  if self.straighten_pages:
- origin_page_orientations = (
- origin_page_orientations
- if self.detect_orientation
- else [estimate_orientation(seq_map) for seq_map in seg_maps]
- )
- pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
  # forward again to get predictions on straight pages
  loc_preds_dict = self.det_predictor(pages, **kwargs) # type: ignore[assignment]

@@ -101,9 +104,8 @@ class OCRPredictor(NestedObject, _OCRPredictor):
  len(loc_pred) == 1 for loc_pred in loc_preds_dict
  ), "Detection Model in ocr_predictor should output only one class"
  loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # type: ignore[union-attr]
-
- # Rectify crops if aspect ratio
- loc_preds = self._remove_padding(pages, loc_preds)
+ # Detach objectness scores from loc_preds
+ loc_preds, objectness_scores = detach_scores(loc_preds)

  # Apply hooks to loc_preds if any
  for hook in self.hooks:
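`detach_scores` is a new helper in `doctr/utils/geometry.py` that replaces the old padding removal here: it splits the objectness score off each localization array so that only plain geometry flows into cropping. A rough behavioural sketch, assuming straight boxes arrive as `(N, 5)` arrays with the score in the last column (the real helper also handles the polygon case):

    import numpy as np

    def detach_scores_sketch(loc_preds):
        # Illustrative stand-in for doctr.utils.geometry.detach_scores:
        # split [xmin, ymin, xmax, ymax, score] rows into geometry and scores.
        boxes = [pred[:, :4] for pred in loc_preds]
        scores = [pred[:, 4] for pred in loc_preds]
        return boxes, scores

    loc_preds = [np.array([[0.10, 0.12, 0.45, 0.30, 0.98]])]
    boxes, objectness_scores = detach_scores_sketch(loc_preds)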
@@ -113,14 +115,20 @@ class OCRPredictor(NestedObject, _OCRPredictor):
  crops, loc_preds = self._prepare_crops(
  pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
  )
- # Rectify crop orientation
+ # Rectify crop orientation and get crop orientation predictions
+ crop_orientations: Any = []
  if not self.assume_straight_pages:
- crops, loc_preds = self._rectify_crops(crops, loc_preds)
+ crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
+ crop_orientations = [
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
+ ]

  # Identify character sequences
  word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
+ if not crop_orientations:
+ crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]

- boxes, text_preds = self._process_predictions(loc_preds, word_preds)
+ boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)

  if self.detect_language:
  languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]
@@ -131,8 +139,10 @@ class OCRPredictor(NestedObject, _OCRPredictor):
  out = self.doc_builder(
  pages,
  boxes,
+ objectness_scores,
  text_preds,
  origin_page_shapes, # type: ignore[arg-type]
+ crop_orientations,
  orientations,
  languages_dict,
  )
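Downstream, the extra per-word fields passed to `doc_builder` end up on the exported document. A hedged example of reading them back; the attribute names `objectness_score` and `crop_orientation` are inferred from the accompanying `doctr/io/elements.py` changes and should be treated as assumptions:

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True)
    result = predictor(DocumentFile.from_images("page.jpg"))  # hypothetical input image
    word = result.pages[0].blocks[0].lines[0].words[0]
    print(word.value, word.confidence)
    print(word.objectness_score)  # detection score detached from the box coordinates
    print(word.crop_orientation)  # e.g. {"value": 0, "confidence": None}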
@@ -79,7 +79,7 @@ class PreProcessor(nn.Module):
  else:
  x = x.to(dtype=torch.float32) # type: ignore[union-attr]

- return x # type: ignore[return-value]
+ return x

  def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]:
  """Prepare document data for model forwarding
@@ -103,7 +103,7 @@ class PreProcessor(nn.Module):
  elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
  raise TypeError("unsupported data type for torch.Tensor")
  # Resizing
- if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]: # type: ignore[union-attr]
+ if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
  x = F.resize(
  x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
  )
@@ -118,11 +118,11 @@ class PreProcessor(nn.Module):
  # Sample transform (to tensor, resize)
  samples = list(multithread_exec(self.sample_transforms, x))
  # Batching
- batches = self.batch_inputs(samples) # type: ignore[assignment]
+ batches = self.batch_inputs(samples)
  else:
  raise TypeError(f"invalid input type: {type(x)}")

  # Batch transforms (normalize)
  batches = list(multithread_exec(self.normalize, batches))

- return batches # type: ignore[return-value]
+ return batches
@@ -41,6 +41,7 @@ class PreProcessor(NestedObject):
  self.resize = Resize(output_size, **kwargs)
  # Perform the division by 255 at the same time
  self.normalize = Normalize(mean, std)
+ self._runs_on_cuda = tf.test.is_gpu_available()

  def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
  """Gather samples into batches for inference purposes
@@ -113,13 +114,13 @@ class PreProcessor(NestedObject):

  elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
  # Sample transform (to tensor, resize)
- samples = list(multithread_exec(self.sample_transforms, x))
+ samples = list(multithread_exec(self.sample_transforms, x, threads=1 if self._runs_on_cuda else None))
  # Batching
  batches = self.batch_inputs(samples)
  else:
  raise TypeError(f"invalid input type: {type(x)}")

  # Batch transforms (normalize)
- batches = list(multithread_exec(self.normalize, batches))
+ batches = list(multithread_exec(self.normalize, batches, threads=1 if self._runs_on_cuda else None))

  return batches
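The TensorFlow preprocessor now pins `multithread_exec` to a single worker when a GPU is visible, which avoids contention between Python threads and the GPU input pipeline. A small sketch of the same guard, written with the non-deprecated GPU check and assuming `multithread_exec(fn, seq, threads=...)` keeps the signature used above:

    import tensorflow as tf
    from doctr.utils.multithreading import multithread_exec

    # Equivalent guard with the non-deprecated GPU check (illustrative only).
    runs_on_gpu = len(tf.config.list_physical_devices("GPU")) > 0

    def normalize(sample):
        return sample / 255.0  # stand-in for the real Normalize transform

    samples = [tf.zeros((32, 128, 3)) for _ in range(8)]
    batches = list(multithread_exec(normalize, samples, threads=1 if runs_on_gpu else None))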
@@ -107,7 +107,7 @@ class MASTER(_MASTER, nn.Module):
  # NOTE: nn.TransformerDecoder takes the inverse from this implementation
  # [True, True, True, ..., False, False, False] -> False is masked
  # (N, 1, 1, max_length)
- target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
+ target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
  target_length = target.size(1)
  # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
  # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
@@ -142,7 +142,7 @@ class MASTER(_MASTER, nn.Module):
  # Input length : number of timesteps
  input_len = model_output.shape[1]
  # Add one for additional <eos> token (sos disappear in shift!)
- seq_len = seq_len + 1 # type: ignore[assignment]
+ seq_len = seq_len + 1
  # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
  # The "masked" first gt char is <sos>. Delete last logit of the model output.
  cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -212,7 +212,7 @@ class PARSeq(_PARSeq, nn.Module):

  sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
  eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
- combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int() # type: ignore
+ combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
  if len(combined) > 1:
  combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
  return combined
@@ -282,7 +282,8 @@ class PARSeq(_PARSeq, nn.Module):
  ys[:, i + 1] = pos_prob.squeeze().argmax(-1)

  # Stop decoding if all sequences have reached the EOS token
- if max_len is None and (ys == self.vocab_size).any(dim=-1).all(): # type: ignore[attr-defined]
+ # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
+ if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
  break

  logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1)
@@ -297,7 +298,7 @@ class PARSeq(_PARSeq, nn.Module):

  # Create padding mask for refined target input maskes all behind EOS token as False
  # (N, 1, 1, max_length)
- target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
+ target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
  mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
  logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
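The new `exportable` guard keeps the decoding loop free of a data-dependent `break` during ONNX export. A minimal export sketch for the PyTorch PARSeq model using plain `torch.onnx.export`; passing `exportable=True` through the model factory and the dummy input shape are assumptions based on the default config:

    import torch
    from doctr.models import parseq

    # Sketch: export the recognition model with the decoding loop fully unrolled.
    model = parseq(pretrained=True, exportable=True).eval()
    dummy_input = torch.rand((1, 3, 32, 128), dtype=torch.float32)  # assumed default input_shape
    torch.onnx.export(
        model,
        dummy_input,
        "parseq.onnx",
        input_names=["input"],
        output_names=["logits"],
        opset_version=14,
    )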
@@ -288,10 +288,11 @@ class PARSeq(_PARSeq, Model):
  )

  # Stop decoding if all sequences have reached the EOS token
- # We need to check it on True to be compatible with ONNX
+ # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
  if (
- max_len is None
- and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1)) is True
+ not self.exportable
+ and max_len is None
+ and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1))
  ):
  break

@@ -125,25 +125,26 @@ class SARDecoder(nn.Module):
  if t == 0:
  # step to init the first states of the LSTMCell
  hidden_state_init = cell_state_init = torch.zeros(
- features.size(0), features.size(1), device=features.device
+ features.size(0), features.size(1), device=features.device, dtype=features.dtype
  )
  hidden_state, cell_state = hidden_state_init, cell_state_init
  prev_symbol = holistic
  elif t == 1:
  # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
  # (N, vocab_size + 1) --> (N, embedding_units)
- prev_symbol = torch.zeros(features.size(0), self.vocab_size + 1, device=features.device)
+ prev_symbol = torch.zeros(
+ features.size(0), self.vocab_size + 1, device=features.device, dtype=features.dtype
+ )
  prev_symbol = self.embed(prev_symbol)
  else:
- if gt is not None:
+ if gt is not None and self.training:
  # (N, embedding_units) -2 because of <bos> and <eos> (same)
  prev_symbol = self.embed(gt_embedding[:, t - 2])
  else:
  # -1 to start at timestep where prev_symbol was initialized
  index = logits_list[t - 1].argmax(-1)
  # update prev_symbol with ones at the index of the previous logit vector
- # (N, embedding_units)
- prev_symbol = prev_symbol.scatter_(1, index.unsqueeze(1), 1)
+ prev_symbol = self.embed(self.embed_tgt(index))

  # (N, C), (N, C) take the last hidden state and cell state from current timestep
  hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
@@ -292,7 +293,7 @@ class SAR(nn.Module, RecognitionModel):
  # Input length : number of timesteps
  input_len = model_output.shape[1]
  # Add one for additional <eos> token
- seq_len = seq_len + 1 # type: ignore[assignment]
+ seq_len = seq_len + 1
  # Compute loss
  # (N, L, vocab_size + 1)
  cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
@@ -177,23 +177,17 @@ class SARDecoder(layers.Layer, NestedObject):
  elif t == 1:
  # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
  # (N, vocab_size + 1) --> (N, embedding_units)
- prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1])
+ prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1], dtype=features.dtype)
  prev_symbol = self.embed(prev_symbol, **kwargs)
  else:
- if gt is not None:
+ if gt is not None and kwargs.get("training", False):
  # (N, embedding_units) -2 because of <bos> and <eos> (same)
  prev_symbol = self.embed(gt_embedding[:, t - 2], **kwargs)
  else:
  # -1 to start at timestep where prev_symbol was initialized
  index = tf.argmax(logits_list[t - 1], axis=-1)
  # update prev_symbol with ones at the index of the previous logit vector
- # (N, embedding_units)
- index = tf.ones_like(index)
- prev_symbol = tf.scatter_nd(
- tf.expand_dims(index, axis=1),
- prev_symbol,
- tf.constant([features.shape[0], features.shape[-1]], dtype=tf.int64),
- )
+ prev_symbol = self.embed(self.embed_tgt(index, **kwargs), **kwargs)

  # (N, C), (N, C) take the last hidden state and cell state from current timestep
  _, states = self.lstm_cells(prev_symbol, states, **kwargs)
@@ -137,7 +137,7 @@ class ViTSTR(_ViTSTR, nn.Module):
  # Input length : number of steps
  input_len = model_output.shape[1]
  # Add one for additional <eos> token (sos disappear in shift!)
- seq_len = seq_len + 1 # type: ignore[assignment]
+ seq_len = seq_len + 1
  # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
  # The "masked" first gt char is <sos>.
  cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -45,7 +45,7 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict

  kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
  kwargs["std"] = kwargs.get("std", _model.cfg["std"])
- kwargs["batch_size"] = kwargs.get("batch_size", 32)
+ kwargs["batch_size"] = kwargs.get("batch_size", 128)
  input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:]
  predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)
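The recognition predictor now batches 128 crops per forward pass by default (previously 32), which speeds up inference but increases memory use. As the hunk above shows, `batch_size` is read from the kwargs before falling back to the new default, so the old behaviour can be restored explicitly:

    from doctr.models import recognition_predictor

    # Sketch: cap the recognition batch size if 128 crops per batch is too large for your hardware.
    reco = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=32)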
doctr/models/zoo.py CHANGED
@@ -61,7 +61,7 @@ def _predictor(


  def ocr_predictor(
- det_arch: Any = "db_resnet50",
+ det_arch: Any = "fast_base",
  reco_arch: Any = "crnn_vgg16_bn",
  pretrained: bool = False,
  pretrained_backbone: bool = True,
@@ -175,7 +175,7 @@ def _kie_predictor(


  def kie_predictor(
- det_arch: Any = "db_resnet50",
+ det_arch: Any = "fast_base",
  reco_arch: Any = "crnn_vgg16_bn",
  pretrained: bool = False,
  pretrained_backbone: bool = True,
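Both high-level predictors now default to the FAST text detector. Pinning the architecture keeps the previous behaviour; a short sketch:

    from doctr.models import kie_predictor, ocr_predictor

    # New defaults: fast_base detection + crnn_vgg16_bn recognition.
    predictor = ocr_predictor(pretrained=True)

    # Pin the previous detection backbone explicitly if you depend on db_resnet50.
    legacy_predictor = ocr_predictor(det_arch="db_resnet50", reco_arch="crnn_vgg16_bn", pretrained=True)
    kie = kie_predictor(det_arch="db_resnet50", pretrained=True)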
doctr/py.typed ADDED
File without changes
@@ -200,4 +200,4 @@ def create_shadow_mask(
  mask: np.ndarray = np.zeros((*target_shape, 1), dtype=np.uint8)
  mask = cv2.fillPoly(mask, [final_contour], (255,), lineType=cv2.LINE_AA)[..., 0]

- return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32) # type: ignore[operator]
+ return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32)
@@ -35,9 +35,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
  rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
  # Inverse the color
  if out.dtype == torch.uint8:
- out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8) # type: ignore[attr-defined]
+ out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
  else:
- out = out * rgb_shift.to(dtype=out.dtype) # type: ignore[attr-defined]
+ out = out * rgb_shift.to(dtype=out.dtype)
  # Inverse the color
  out = 255 - out if out.dtype == torch.uint8 else 1 - out
  return out
@@ -81,7 +81,7 @@ def rotate_sample(
  rotated_geoms: np.ndarray = rotate_abs_geoms(
  _geoms,
  angle,
- img.shape[1:],
+ img.shape[1:], # type: ignore[arg-type]
  expand,
  ).astype(np.float32)

@@ -132,7 +132,7 @@ def random_shadow(img: torch.Tensor, opacity_range: Tuple[float, float], **kwarg
  -------
  shaded image
  """
- shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
+ shadow_mask = create_shadow_mask(img.shape[1:], **kwargs) # type: ignore[arg-type]

  opacity = np.random.uniform(*opacity_range)
  shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])
@@ -5,7 +5,7 @@

  import math
  import random
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable, List, Optional, Tuple, Union

  import numpy as np

@@ -168,11 +168,11 @@ class OneOf(NestedObject):
  def __init__(self, transforms: List[Callable[[Any], Any]]) -> None:
  self.transforms = transforms

- def __call__(self, img: Any) -> Any:
+ def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
  # Pick transformation
  transfo = self.transforms[int(random.random() * len(self.transforms))]
  # Apply
- return transfo(img)
+ return transfo(img) if target is None else transfo(img, target) # type: ignore[call-arg]


  class RandomApply(NestedObject):
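`OneOf` can now forward a detection target to whichever transform it samples. A usage sketch with relative boxes; the wrapped transforms must themselves accept an `(img, target)` pair:

    import numpy as np
    import torch
    from doctr.transforms import OneOf, RandomHorizontalFlip

    # Sketch: the sampled transform receives (img, target) when a target is passed.
    transfo = OneOf([RandomHorizontalFlip(p=1.0)])
    img = torch.rand((3, 64, 64))
    boxes = np.array([[0.1, 0.2, 0.4, 0.5]], dtype=np.float32)  # relative (N, 4) boxes
    out_img, out_boxes = transfo(img, boxes)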
@@ -261,17 +261,39 @@ class RandomCrop(NestedObject):
  def extra_repr(self) -> str:
  return f"scale={self.scale}, ratio={self.ratio}"

- def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]:
+ def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
  scale = random.uniform(self.scale[0], self.scale[1])
  ratio = random.uniform(self.ratio[0], self.ratio[1])
- # Those might overflow
- crop_h = math.sqrt(scale * ratio)
- crop_w = math.sqrt(scale / ratio)
- xmin, ymin = random.uniform(0, 1 - crop_w), random.uniform(0, 1 - crop_h)
- xmax, ymax = xmin + crop_w, ymin + crop_h
- # Clip them
- xmin, ymin = max(xmin, 0), max(ymin, 0)
- xmax, ymax = min(xmax, 1), min(ymax, 1)
-
- croped_img, crop_boxes = F.crop_detection(img, target["boxes"], (xmin, ymin, xmax, ymax))
- return croped_img, dict(boxes=crop_boxes)
+
+ height, width = img.shape[:2]
+
+ # Calculate crop size
+ crop_area = scale * width * height
+ aspect_ratio = ratio * (width / height)
+ crop_width = int(round(math.sqrt(crop_area * aspect_ratio)))
+ crop_height = int(round(math.sqrt(crop_area / aspect_ratio)))
+
+ # Ensure crop size does not exceed image dimensions
+ crop_width = min(crop_width, width)
+ crop_height = min(crop_height, height)
+
+ # Randomly select crop position
+ x = random.randint(0, width - crop_width)
+ y = random.randint(0, height - crop_height)
+
+ # relative crop box
+ crop_box = (x / width, y / height, (x + crop_width) / width, (y + crop_height) / height)
+ if target.shape[1:] == (4, 2):
+ min_xy = np.min(target, axis=1)
+ max_xy = np.max(target, axis=1)
+ _target = np.concatenate((min_xy, max_xy), axis=1)
+ else:
+ _target = target
+
+ # Crop image and targets
+ croped_img, crop_boxes = F.crop_detection(img, _target, crop_box)
+ # hard fallback if no box is kept
+ if crop_boxes.shape[0] == 0:
+ return img, target
+ # clip boxes
+ return croped_img, np.clip(crop_boxes, 0, 1)
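The `RandomCrop` target changes from a `{"boxes": ...}` dict to a bare array, and the transform now returns the untouched input if the sampled crop would drop every box. The two accepted target layouts, sketched:

    import numpy as np

    # Relative straight boxes: (N, 4) as [xmin, ymin, xmax, ymax]
    boxes = np.array([[0.2, 0.2, 0.6, 0.4]], dtype=np.float32)
    # 4-point polygons: (N, 4, 2); reduced to their bounding box before cropping
    polys = np.array([[[0.2, 0.2], [0.6, 0.2], [0.6, 0.4], [0.2, 0.4]]], dtype=np.float32)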
@@ -4,7 +4,7 @@
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import math
- from typing import Any, Dict, Optional, Tuple, Union
+ from typing import Optional, Tuple, Union

  import numpy as np
  import torch
@@ -15,7 +15,7 @@ from torchvision.transforms import transforms as T

  from ..functional.pytorch import random_shadow

- __all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow"]
+ __all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"]


  class Resize(T.Resize):
@@ -135,9 +135,9 @@ class GaussianNoise(torch.nn.Module):
  # Reshape the distribution
  noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std
  if x.dtype == torch.uint8:
- return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8) # type: ignore[attr-defined]
+ return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
  else:
- return (x + noise.to(dtype=x.dtype)).clamp(0, 1) # type: ignore[attr-defined]
+ return (x + noise.to(dtype=x.dtype)).clamp(0, 1)

  def extra_repr(self) -> str:
  return f"mean={self.mean}, std={self.std}"
@@ -159,13 +159,16 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
  """Randomly flip the input image horizontally"""

  def forward(
- self, img: Union[torch.Tensor, Image], target: Dict[str, Any]
- ) -> Tuple[Union[torch.Tensor, Image], Dict[str, Any]]:
+ self, img: Union[torch.Tensor, Image], target: np.ndarray
+ ) -> Tuple[Union[torch.Tensor, Image], np.ndarray]:
  if torch.rand(1) < self.p:
  _img = F.hflip(img)
  _target = target.copy()
  # Changing the relative bbox coordinates
- _target["boxes"][:, ::2] = 1 - target["boxes"][:, [2, 0]]
+ if target.shape[1:] == (4,):
+ _target[:, ::2] = 1 - target[:, [2, 0]]
+ else:
+ _target[..., 0] = 1 - target[..., 0]
  return _img, _target
  return img, target
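`RandomHorizontalFlip` follows the same convention: `(N, 4)` relative boxes get their x-extent mirrored, while any other layout (e.g. `(N, 4, 2)` polygons) has its x-coordinates flipped in place. A quick sketch:

    import numpy as np
    import torch
    from doctr.transforms import RandomHorizontalFlip

    # Sketch: flipping a polygon target; every x-coordinate is mirrored around the vertical axis.
    transfo = RandomHorizontalFlip(p=1.0)
    img = torch.rand((3, 64, 64))
    polys = np.array([[[0.1, 0.2], [0.4, 0.2], [0.4, 0.5], [0.1, 0.5]]], dtype=np.float32)
    flipped_img, flipped_polys = transfo(img, polys)  # x -> 1 - x for every point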
@@ -199,7 +202,7 @@ class RandomShadow(torch.nn.Module):
  self.opacity_range,
  )
  )
- .round() # type: ignore[attr-defined]
+ .round()
  .clip(0, 255)
  .to(dtype=torch.uint8)
  )
@@ -210,3 +213,58 @@

  def extra_repr(self) -> str:
  return f"opacity_range={self.opacity_range}"
+
+
+ class RandomResize(torch.nn.Module):
+ """Randomly resize the input image and align corresponding targets
+
+ >>> import torch
+ >>> from doctr.transforms import RandomResize
+ >>> transfo = RandomResize((0.3, 0.9), preserve_aspect_ratio=True, symmetric_pad=True, p=0.5)
+ >>> out = transfo(torch.rand((3, 64, 64)))
+
+ Args:
+ ----
+ scale_range: range of the resizing factor for width and height (independently)
+ preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
+ given a float value, the aspect ratio will be preserved with this probability
+ symmetric_pad: whether to symmetrically pad the image,
+ given a float value, the symmetric padding will be applied with this probability
+ p: probability to apply the transformation
+ """
+
+ def __init__(
+ self,
+ scale_range: Tuple[float, float] = (0.3, 0.9),
+ preserve_aspect_ratio: Union[bool, float] = False,
+ symmetric_pad: Union[bool, float] = False,
+ p: float = 0.5,
+ ) -> None:
+ super().__init__()
+ self.scale_range = scale_range
+ self.preserve_aspect_ratio = preserve_aspect_ratio
+ self.symmetric_pad = symmetric_pad
+ self.p = p
+ self._resize = Resize
+
+ def forward(self, img: torch.Tensor, target: np.ndarray) -> Tuple[torch.Tensor, np.ndarray]:
+ if torch.rand(1) < self.p:
+ scale_h = np.random.uniform(*self.scale_range)
+ scale_w = np.random.uniform(*self.scale_range)
+ new_size = (int(img.shape[-2] * scale_h), int(img.shape[-1] * scale_w))
+
+ _img, _target = self._resize(
+ new_size,
+ preserve_aspect_ratio=self.preserve_aspect_ratio
+ if isinstance(self.preserve_aspect_ratio, bool)
+ else bool(torch.rand(1) <= self.symmetric_pad),
+ symmetric_pad=self.symmetric_pad
+ if isinstance(self.symmetric_pad, bool)
+ else bool(torch.rand(1) <= self.symmetric_pad),
+ )(img, target)
+
+ return _img, _target
+ return img, target
+
+ def extra_repr(self) -> str:
+ return f"scale_range={self.scale_range}, preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}, p={self.p}" # noqa: E501