python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
doctr/models/recognition/vitstr/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -14,7 +14,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS
 
 from ...classification import vit_b, vit_s
-from ...utils.pytorch import load_pretrained_params
+from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -25,14 +25,14 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 128),
         "vocab": VOCABS["french"],
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/vitstr_small-fcd12655.pt&src=0",
     },
     "vitstr_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 128),
         "vocab": VOCABS["french"],
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/vitstr_base-50b21df2.pt&src=0",
     },
 }
 
@@ -42,6 +42,7 @@ class ViTSTR(_ViTSTR, nn.Module):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
+    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -95,7 +96,7 @@ class ViTSTR(_ViTSTR, nn.Module):
         B, N, E = features.size()
         features = features.reshape(B * N, E)
         logits = self.head(features).view(B, N, len(self.vocab) + 1)  # (batch_size, max_length, vocab + 1)
-        decoded_features = logits[:, 1:]  # remove cls_token
+        decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
         out: Dict[str, Any] = {}
         if self.exportable:
@@ -124,17 +125,19 @@ class ViTSTR(_ViTSTR, nn.Module):
         Sequences are masked after the EOS character.
 
         Args:
+        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
+        -------
             The loss of the model on the batch
         """
         # Input length : number of steps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1  # type: ignore[assignment]
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>.
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -150,6 +153,7 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
+    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
@@ -159,18 +163,19 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
-        # N x L
-        probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
-        # Take the minimum confidence of the sequence
-        probs = probs.min(dim=1).values.detach().cpu()
+        preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
 
         # Manual decoding
         word_values = [
             "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
             for encoded_seq in out_idxs.cpu().numpy()
         ]
+        # compute probabilities for each word up to the EOS token
+        probs = [
+            preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
+        ]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        return list(zip(word_values, probs))
 
 
 def _vitstr(
@@ -223,12 +228,14 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
+    -------
         text recognition architecture
     """
-
     return _vitstr(
         "vitstr_small",
         pretrained,
@@ -252,12 +259,14 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
+    -------
         text recognition architecture
     """
-
     return _vitstr(
         "vitstr_base",
         pretrained,
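The post-processor hunk above replaces the old sequence-minimum confidence with a per-word mean taken only over the characters kept before the <eos> token. A minimal PyTorch sketch of the two schemes side by side (the logits tensor and the 3-character word lengths are hypothetical, for illustration only):

    import torch

    # hypothetical batch: 2 sequences, 5 decoding steps, 4-symbol vocab + EOS
    logits = torch.randn(2, 5, 5)
    out_idxs = logits.argmax(-1)

    # 0.7.0 scheme: probability of each argmax symbol, minimum over the sequence
    probs_old = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
    conf_old = probs_old.min(dim=1).values

    # 0.8.1 scheme: max softmax probability per step, averaged over the decoded
    # word length (here both words are assumed to be 3 characters long)
    preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
    conf_new = [preds_prob[i, :3].clip(0, 1).mean().item() for i in range(2)]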
doctr/models/recognition/vitstr/tensorflow.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,7 +12,7 @@ from tensorflow.keras import Model, layers
 from doctr.datasets import VOCABS
 
 from ...classification import vit_b, vit_s
-from ...utils.tensorflow import load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -40,6 +40,7 @@ class ViTSTR(_ViTSTR, Model):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
+    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -84,11 +85,13 @@ class ViTSTR(_ViTSTR, Model):
         Sequences are masked after the EOS character.
 
         Args:
+        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
+        -------
             The loss of the model on the batch
         """
         # Input length : number of steps
@@ -131,7 +134,7 @@ class ViTSTR(_ViTSTR, Model):
         logits = tf.reshape(
             self.head(features, **kwargs), (B, N, len(self.vocab) + 1)
         )  # (batch_size, max_length, vocab + 1)
-        decoded_features = logits[:, 1:]  # remove cls_token
+        decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
         out: Dict[str, tf.Tensor] = {}
         if self.exportable:
@@ -155,6 +158,7 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
+    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
@@ -164,10 +168,7 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
-        # N x L
-        probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
-        # Take the minimum confidence of the sequence
-        probs = tf.math.reduce_min(probs, axis=1)
+        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
 
         # decode raw output of the model with tf_label_to_idx
         out_idxs = tf.cast(out_idxs, dtype="int32")
@@ -177,7 +178,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        # compute probabilities for each word up to the EOS token
+        probs = [
+            preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
+            for i, word in enumerate(word_values)
+        ]
+
+        return list(zip(word_values, probs))
 
 
 def _vitstr(
@@ -227,12 +234,14 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
+    -------
         text recognition architecture
     """
-
     return _vitstr(
         "vitstr_small",
         pretrained,
@@ -254,12 +263,14 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
+    -------
         text recognition architecture
     """
-
     return _vitstr(
         "vitstr_base",
         pretrained,
doctr/models/recognition/zoo.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -63,11 +63,13 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False,
     >>> out = model([input_page])
 
     Args:
+    ----
         arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
         pretrained: If True, returns a model pre-trained on our text recognition dataset
+        **kwargs: optional parameters to be passed to the architecture
 
     Returns:
+    -------
         Recognition predictor
     """
-
     return _predictor(arch, pretrained, **kwargs)
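For context, the factory documented above is used exactly as in its docstring example; a short sketch (pretrained weights are fetched on first use):

    import numpy as np
    from doctr.models import recognition_predictor

    model = recognition_predictor("crnn_vgg16_bn", pretrained=True)
    input_page = (255 * np.random.rand(32, 128, 3)).astype(np.uint8)
    out = model([input_page])  # list of (word, confidence) predictions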
doctr/models/utils/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -11,18 +11,29 @@ from torch import nn
 
 from doctr.utils.data import download_from_url
 
-__all__ = ["load_pretrained_params", "conv_sequence_pt", "set_device_and_dtype", "export_model_to_onnx", "_copy_tensor"]
+__all__ = [
+    "load_pretrained_params",
+    "conv_sequence_pt",
+    "set_device_and_dtype",
+    "export_model_to_onnx",
+    "_copy_tensor",
+    "_bf16_to_float32",
+]
 
 
 def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
     return x.clone().detach()
 
 
+def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
+    # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
+    return x.float() if x.dtype == torch.bfloat16 else x
+
+
 def load_pretrained_params(
     model: nn.Module,
     url: Optional[str] = None,
     hash_prefix: Optional[str] = None,
-    overwrite: bool = False,
     ignore_keys: Optional[List[str]] = None,
     **kwargs: Any,
 ) -> None:
@@ -32,13 +43,13 @@ def load_pretrained_params(
     >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
 
     Args:
+    ----
         model: the PyTorch model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
-        overwrite: should the zip extraction be enforced if the archive has already been extracted
         ignore_keys: list of weights to be ignored from the state_dict
+        **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
-
     if url is None:
         logging.warning("Invalid model URL, using default initialization.")
     else:
@@ -73,11 +84,15 @@ def conv_sequence_pt(
     >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3))
 
     Args:
+    ----
+        in_channels: number of input channels
         out_channels: number of output channels
         relu: whether ReLU should be used
         bn: should a batch normalization layer be added
+        **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
+    -------
         list of layers
     """
     # No bias before Batch norm
@@ -107,15 +122,16 @@ def set_device_and_dtype(
     >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)
 
     Args:
+    ----
         model: the model to be set
         batches: the batches to be set
         device: the device to be used
         dtype: the dtype to be used
 
     Returns:
+    -------
         the model and batches set
     """
-
     return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches]
 
 
@@ -129,12 +145,14 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))
 
     Args:
+    ----
         model: the PyTorch model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to torch.onnx.export
 
     Returns:
+    -------
         the path to the exported model
     """
     torch.onnx.export(
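The new `_bf16_to_float32` helper exists because `.numpy()` rejects bfloat16 tensors, as the in-code comment notes. A quick sketch of the behaviour it guards (import path taken from this diff):

    import torch
    from doctr.models.utils.pytorch import _bf16_to_float32

    x = torch.randn(2, 3, dtype=torch.bfloat16)
    # x.numpy() would raise here: bfloat16 has no numpy counterpart
    y = _bf16_to_float32(x)  # upcast to float32 only when needed
    assert y.dtype == torch.float32 and y.numpy().shape == (2, 3)

    z = torch.randn(2, 3)  # non-bf16 tensors pass through unchanged
    assert _bf16_to_float32(z) is z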
doctr/models/utils/tensorflow.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -17,13 +17,25 @@ from doctr.utils.data import download_from_url
 logging.getLogger("tensorflow").setLevel(logging.DEBUG)
 
 
-__all__ = ["load_pretrained_params", "conv_sequence", "IntermediateLayerGetter", "export_model_to_onnx", "_copy_tensor"]
+__all__ = [
+    "load_pretrained_params",
+    "conv_sequence",
+    "IntermediateLayerGetter",
+    "export_model_to_onnx",
+    "_copy_tensor",
+    "_bf16_to_float32",
+]
 
 
 def _copy_tensor(x: tf.Tensor) -> tf.Tensor:
     return tf.identity(x)
 
 
+def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
+    # Convert bfloat16 to float32 for numpy compatibility
+    return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x
+
+
 def load_pretrained_params(
     model: Model,
     url: Optional[str] = None,
@@ -38,13 +50,14 @@ def load_pretrained_params(
     >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
 
     Args:
+    ----
         model: the keras model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
         overwrite: should the zip extraction be enforced if the archive has already been extracted
         internal_name: name of the ckpt files
+        **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
-
     if url is None:
         logging.warning("Invalid model URL, using default initialization.")
     else:
@@ -75,13 +88,16 @@ def conv_sequence(
     >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))
 
     Args:
+    ----
         out_channels: number of output channels
         activation: activation to be used (default: no activation)
         bn: should a batch normalization layer be added
         padding: padding scheme
         kernel_initializer: kernel initializer
+        **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
+    -------
         list of layers
     """
     # No bias before Batch norm
@@ -109,6 +125,7 @@ class IntermediateLayerGetter(Model):
     >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers)
 
     Args:
+    ----
         model: the model to extract feature maps from
         layer_names: the list of layers to retrieve the feature map from
     """
@@ -134,12 +151,14 @@ def export_model_to_onnx(
     >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")])
 
     Args:
+    ----
         model: the keras model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to tf2onnx
 
     Returns:
+    -------
         the path to the exported model and a list with the output layer names
     """
     large_model = kwargs.get("large_model", False)
doctr/models/zoo.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -24,6 +24,7 @@ def _predictor(
     det_bs: int = 2,
     reco_bs: int = 128,
     detect_orientation: bool = False,
+    straighten_pages: bool = False,
     detect_language: bool = False,
     **kwargs,
 ) -> OCRPredictor:
@@ -53,6 +54,7 @@ def _predictor(
         preserve_aspect_ratio=preserve_aspect_ratio,
         symmetric_pad=symmetric_pad,
         detect_orientation=detect_orientation,
+        straighten_pages=straighten_pages,
         detect_language=detect_language,
         **kwargs,
     )
@@ -68,6 +70,7 @@ def ocr_predictor(
     symmetric_pad: bool = True,
     export_as_straight_boxes: bool = False,
     detect_orientation: bool = False,
+    straighten_pages: bool = False,
     detect_language: bool = False,
     **kwargs: Any,
 ) -> OCRPredictor:
@@ -80,6 +83,7 @@ def ocr_predictor(
     >>> out = model([input_page])
 
     Args:
+    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -95,14 +99,18 @@ def ocr_predictor(
             (potentially rotated) as straight bounding boxes.
         detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
             page. Doing so will slightly deteriorate the overall latency.
+        straighten_pages: if True, estimates the page general orientation
+            based on the segmentation map median line orientation.
+            Then, rotates page before passing it again to the deep learning detection module.
+            Doing so will improve performances for documents with page-uniform rotations.
         detect_language: if True, the language prediction will be added to the predictions for each
             page. Doing so will slightly deteriorate the overall latency.
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
+    -------
         OCR predictor
     """
-
     return _predictor(
         det_arch,
         reco_arch,
@@ -113,6 +121,7 @@ def ocr_predictor(
         symmetric_pad=symmetric_pad,
         export_as_straight_boxes=export_as_straight_boxes,
         detect_orientation=detect_orientation,
+        straighten_pages=straighten_pages,
         detect_language=detect_language,
         **kwargs,
     )
@@ -129,6 +138,7 @@ def _kie_predictor(
     det_bs: int = 2,
     reco_bs: int = 128,
     detect_orientation: bool = False,
+    straighten_pages: bool = False,
     detect_language: bool = False,
     **kwargs,
 ) -> KIEPredictor:
@@ -158,6 +168,7 @@ def _kie_predictor(
         preserve_aspect_ratio=preserve_aspect_ratio,
         symmetric_pad=symmetric_pad,
         detect_orientation=detect_orientation,
+        straighten_pages=straighten_pages,
         detect_language=detect_language,
         **kwargs,
     )
@@ -173,6 +184,7 @@ def kie_predictor(
     symmetric_pad: bool = True,
     export_as_straight_boxes: bool = False,
     detect_orientation: bool = False,
+    straighten_pages: bool = False,
     detect_language: bool = False,
     **kwargs: Any,
 ) -> KIEPredictor:
@@ -185,6 +197,7 @@ def kie_predictor(
     >>> out = model([input_page])
 
     Args:
+    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -200,14 +213,18 @@ def kie_predictor(
             (potentially rotated) as straight bounding boxes.
         detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
             page. Doing so will slightly deteriorate the overall latency.
+        straighten_pages: if True, estimates the page general orientation
+            based on the segmentation map median line orientation.
+            Then, rotates page before passing it again to the deep learning detection module.
+            Doing so will improve performances for documents with page-uniform rotations.
         detect_language: if True, the language prediction will be added to the predictions for each
             page. Doing so will slightly deteriorate the overall latency.
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
+    -------
         KIE predictor
     """
-
     return _kie_predictor(
         det_arch,
         reco_arch,
@@ -218,6 +235,7 @@ def kie_predictor(
         symmetric_pad=symmetric_pad,
         export_as_straight_boxes=export_as_straight_boxes,
         detect_orientation=detect_orientation,
+        straighten_pages=straighten_pages,
         detect_language=detect_language,
         **kwargs,
     )
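The `straighten_pages` flag added above threads through both factories; a usage sketch (architecture names taken from the docstring example, weights download on first use):

    import numpy as np
    from doctr.models import ocr_predictor

    # deskew pages before detection: orientation is estimated from the
    # segmentation map, then the rotated page is re-run through detection
    predictor = ocr_predictor(
        det_arch="db_resnet50",
        reco_arch="crnn_vgg16_bn",
        pretrained=True,
        straighten_pages=True,
    )
    page = (255 * np.random.rand(1024, 1024, 3)).astype(np.uint8)
    result = predictor([page])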
doctr/transforms/functional/base.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -20,10 +20,12 @@ def crop_boxes(
     """Crop localization boxes
 
     Args:
+    ----
         boxes: ndarray of shape (N, 4) in relative or abs coordinates
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in the same coord format that the boxes
 
     Returns:
+    -------
         the cropped boxes
     """
     is_box_rel = boxes.max() <= 1
@@ -52,10 +54,12 @@ def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float,
     the same direction until we meet one of the edges.
 
     Args:
+    ----
         line: array of shape (2, 2) of the point supposed to be on one edge, and the shadow tip.
         target_shape: the desired mask shape
 
     Returns:
+    -------
         2D coordinates of the first point once we extended the line (on one of the edges)
     """
     if any(coord == 0 or coord == size for coord, size in zip(line[0], target_shape[::-1])):
@@ -116,15 +120,16 @@ def create_shadow_mask(
     """Creates a random shadow mask
 
     Args:
+    ----
         target_shape: the target shape (H, W)
         min_base_width: the relative minimum shadow base width
         max_tip_width: the relative maximum shadow tip width
         max_tip_height: the relative maximum shadow tip height
 
     Returns:
+    -------
         a numpy ndarray of shape (H, W, 1) with values in the range [0, 1]
     """
-
     # Default base is top
     _params = np.random.rand(6)
     base_width = min_base_width + (1 - min_base_width) * _params[0]
@@ -195,4 +200,4 @@ def create_shadow_mask(
     mask: np.ndarray = np.zeros((*target_shape, 1), dtype=np.uint8)
     mask = cv2.fillPoly(mask, [final_contour], (255,), lineType=cv2.LINE_AA)[..., 0]
 
-    return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32)
+    return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32)  # type: ignore[operator]
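To close, a sketch of how the shadow-mask helper above can be applied to an image (only `target_shape` is passed, assuming the remaining parameters carry defaults as the docstring suggests):

    import numpy as np
    from doctr.transforms.functional.base import create_shadow_mask

    mask = create_shadow_mask((512, 512))  # float32 values in [0, 1]
    img = np.random.rand(512, 512, 3).astype(np.float32)
    # darken the image under the shadow; atleast_3d keeps the blend
    # shape-agnostic whether the mask comes back as (H, W) or (H, W, 1)
    shadowed = img * (1.0 - 0.5 * np.atleast_3d(mask))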