python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -15,60 +15,51 @@ from tensorflow.keras import Model, Sequential, layers
15
15
 
16
16
  from doctr.file_utils import CLASS_NAME
17
17
  from doctr.models.classification import resnet18, resnet34, resnet50
18
- from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params
18
+ from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
19
19
  from doctr.utils.repr import NestedObject
20
20
 
21
21
  from .base import LinkNetPostProcessor, _LinkNet
22
22
 
23
- __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50", "linknet_resnet18_rotation"]
23
+ __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
24
24
 
25
25
  default_cfgs: Dict[str, Dict[str, Any]] = {
26
26
  "linknet_resnet18": {
27
27
  "mean": (0.798, 0.785, 0.772),
28
28
  "std": (0.264, 0.2749, 0.287),
29
29
  "input_shape": (1024, 1024, 3),
30
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/linknet_resnet18-611b50f2.zip&src=0",
31
- },
32
- "linknet_resnet18_rotation": {
33
- "mean": (0.798, 0.785, 0.772),
34
- "std": (0.264, 0.2749, 0.287),
35
- "input_shape": (1024, 1024, 3),
36
- "url": "https://doctr-static.mindee.com/models?id=v0.5.0/linknet_resnet18-a48e6ed3.zip&src=0",
30
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-b9ee56e6.zip&src=0",
37
31
  },
38
32
  "linknet_resnet34": {
39
33
  "mean": (0.798, 0.785, 0.772),
40
34
  "std": (0.264, 0.2749, 0.287),
41
35
  "input_shape": (1024, 1024, 3),
42
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/linknet_resnet34-bf30afb1.zip&src=0",
36
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-51909c56.zip&src=0",
43
37
  },
44
38
  "linknet_resnet50": {
45
39
  "mean": (0.798, 0.785, 0.772),
46
40
  "std": (0.264, 0.2749, 0.287),
47
41
  "input_shape": (1024, 1024, 3),
48
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/linknet_resnet50-cd299262.zip&src=0",
42
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-ac9f3829.zip&src=0",
49
43
  },
50
44
  }
51
45
 
52
46
 
53
47
  def decoder_block(in_chan: int, out_chan: int, stride: int, **kwargs: Any) -> Sequential:
54
48
  """Creates a LinkNet decoder block"""
55
-
56
- return Sequential(
57
- [
58
- *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs),
59
- layers.Conv2DTranspose(
60
- filters=in_chan // 4,
61
- kernel_size=3,
62
- strides=stride,
63
- padding="same",
64
- use_bias=False,
65
- kernel_initializer="he_normal",
66
- ),
67
- layers.BatchNormalization(),
68
- layers.Activation("relu"),
69
- *conv_sequence(out_chan, "relu", True, kernel_size=1),
70
- ]
71
- )
49
+ return Sequential([
50
+ *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs),
51
+ layers.Conv2DTranspose(
52
+ filters=in_chan // 4,
53
+ kernel_size=3,
54
+ strides=stride,
55
+ padding="same",
56
+ use_bias=False,
57
+ kernel_initializer="he_normal",
58
+ ),
59
+ layers.BatchNormalization(),
60
+ layers.Activation("relu"),
61
+ *conv_sequence(out_chan, "relu", True, kernel_size=1),
62
+ ])
72
63
 
73
64
 
74
65
  class LinkNetFPN(Model, NestedObject):
@@ -104,8 +95,11 @@ class LinkNet(_LinkNet, keras.Model):
104
95
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
105
96
 
106
97
  Args:
98
+ ----
107
99
  feature extractor: the backbone serving as feature extractor
108
100
  fpn_channels: number of channels each extracted feature maps is mapped to
101
+ bin_thresh: threshold for binarization of the output feature map
102
+ box_thresh: minimal objectness score to consider a box
109
103
  assume_straight_pages: if True, fit straight bounding boxes only
110
104
  exportable: onnx exportable returns only logits
111
105
  cfg: the configuration dict of the model
@@ -119,6 +113,7 @@ class LinkNet(_LinkNet, keras.Model):
119
113
  feat_extractor: IntermediateLayerGetter,
120
114
  fpn_channels: int = 64,
121
115
  bin_thresh: float = 0.1,
116
+ box_thresh: float = 0.1,
122
117
  assume_straight_pages: bool = True,
123
118
  exportable: bool = False,
124
119
  cfg: Optional[Dict[str, Any]] = None,
@@ -137,32 +132,32 @@ class LinkNet(_LinkNet, keras.Model):
137
132
  self.fpn = LinkNetFPN(fpn_channels, [_shape[1:] for _shape in self.feat_extractor.output_shape])
138
133
  self.fpn.build(self.feat_extractor.output_shape)
139
134
 
140
- self.classifier = Sequential(
141
- [
142
- layers.Conv2DTranspose(
143
- filters=32,
144
- kernel_size=3,
145
- strides=2,
146
- padding="same",
147
- use_bias=False,
148
- kernel_initializer="he_normal",
149
- input_shape=self.fpn.decoders[-1].output_shape[1:],
150
- ),
151
- layers.BatchNormalization(),
152
- layers.Activation("relu"),
153
- *conv_sequence(32, "relu", True, kernel_size=3, strides=1),
154
- layers.Conv2DTranspose(
155
- filters=num_classes,
156
- kernel_size=2,
157
- strides=2,
158
- padding="same",
159
- use_bias=True,
160
- kernel_initializer="he_normal",
161
- ),
162
- ]
163
- )
135
+ self.classifier = Sequential([
136
+ layers.Conv2DTranspose(
137
+ filters=32,
138
+ kernel_size=3,
139
+ strides=2,
140
+ padding="same",
141
+ use_bias=False,
142
+ kernel_initializer="he_normal",
143
+ input_shape=self.fpn.decoders[-1].output_shape[1:],
144
+ ),
145
+ layers.BatchNormalization(),
146
+ layers.Activation("relu"),
147
+ *conv_sequence(32, "relu", True, kernel_size=3, strides=1),
148
+ layers.Conv2DTranspose(
149
+ filters=num_classes,
150
+ kernel_size=2,
151
+ strides=2,
152
+ padding="same",
153
+ use_bias=True,
154
+ kernel_initializer="he_normal",
155
+ ),
156
+ ])
164
157
 
165
- self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
158
+ self.postprocessor = LinkNetPostProcessor(
159
+ assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
160
+ )
166
161
 
167
162
  def compute_loss(
168
163
  self,
@@ -176,6 +171,7 @@ class LinkNet(_LinkNet, keras.Model):
176
171
  <https://github.com/tensorflow/addons/>`_.
177
172
 
178
173
  Args:
174
+ ----
179
175
  out_map: output feature map of the model of shape N x H x W x 1
180
176
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
181
177
  gamma: modulating factor in the focal loss formula
@@ -183,6 +179,7 @@ class LinkNet(_LinkNet, keras.Model):
183
179
  eps: epsilon factor in dice loss
184
180
 
185
181
  Returns:
182
+ -------
186
183
  A loss tensor
187
184
  """
188
185
  seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True)
@@ -204,10 +201,12 @@ class LinkNet(_LinkNet, keras.Model):
204
201
  # Class reduced
205
202
  focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))
206
203
 
207
- # Dice loss
208
- inter = tf.math.reduce_sum(seg_mask * proba_map * seg_target, (0, 1, 2, 3))
209
- cardinality = tf.math.reduce_sum((proba_map + seg_target), (0, 1, 2, 3))
210
- dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)
204
+ # Compute dice loss for each class
205
+ dice_map = tf.nn.softmax(out_map, axis=-1) if len(self.class_names) > 1 else proba_map
206
+ # Class-reduced dice loss
207
+ inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
208
+ cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
209
+ dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))
211
210
 
212
211
  return focal_loss + dice_loss
213
212
 
@@ -229,7 +228,8 @@ class LinkNet(_LinkNet, keras.Model):
229
228
  return out
230
229
 
231
230
  if return_model_output or target is None or return_preds:
232
- prob_map = tf.math.sigmoid(logits)
231
+ prob_map = _bf16_to_float32(tf.math.sigmoid(logits))
232
+
233
233
  if return_model_output:
234
234
  out["out_map"] = prob_map
235
235
 
@@ -293,12 +293,14 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
293
293
  >>> out = model(input_tensor)
294
294
 
295
295
  Args:
296
+ ----
296
297
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
298
+ **kwargs: keyword arguments of the LinkNet architecture
297
299
 
298
300
  Returns:
301
+ -------
299
302
  text detection architecture
300
303
  """
301
-
302
304
  return _linknet(
303
305
  "linknet_resnet18",
304
306
  pretrained,
@@ -308,32 +310,6 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
308
310
  )
309
311
 
310
312
 
311
- def linknet_resnet18_rotation(pretrained: bool = False, **kwargs: Any) -> LinkNet:
312
- """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
313
- <https://arxiv.org/pdf/1707.03718.pdf>`_.
314
-
315
- >>> import tensorflow as tf
316
- >>> from doctr.models import linknet_resnet18_rotation
317
- >>> model = linknet_resnet18_rotation(pretrained=True)
318
- >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
319
- >>> out = model(input_tensor)
320
-
321
- Args:
322
- pretrained (bool): If True, returns a model pre-trained on our text detection dataset
323
-
324
- Returns:
325
- text detection architecture
326
- """
327
-
328
- return _linknet(
329
- "linknet_resnet18_rotation",
330
- pretrained,
331
- resnet18,
332
- ["resnet_block_1", "resnet_block_3", "resnet_block_5", "resnet_block_7"],
333
- **kwargs,
334
- )
335
-
336
-
337
313
  def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
338
314
  """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
339
315
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
@@ -345,12 +321,14 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
345
321
  >>> out = model(input_tensor)
346
322
 
347
323
  Args:
324
+ ----
348
325
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
326
+ **kwargs: keyword arguments of the LinkNet architecture
349
327
 
350
328
  Returns:
329
+ -------
351
330
  text detection architecture
352
331
  """
353
-
354
332
  return _linknet(
355
333
  "linknet_resnet34",
356
334
  pretrained,
@@ -371,12 +349,14 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
371
349
  >>> out = model(input_tensor)
372
350
 
373
351
  Args:
352
+ ----
374
353
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
354
+ **kwargs: keyword arguments of the LinkNet architecture
375
355
 
376
356
  Returns:
357
+ -------
377
358
  text detection architecture
378
359
  """
379
-
380
360
  return _linknet(
381
361
  "linknet_resnet50",
382
362
  pretrained,
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, List, Union
6
+ from typing import Any, Dict, List, Tuple, Union
7
7
 
8
8
  import numpy as np
9
9
  import torch
@@ -19,6 +19,7 @@ class DetectionPredictor(nn.Module):
19
19
  """Implements an object able to localize text elements in a document
20
20
 
21
21
  Args:
22
+ ----
22
23
  pre_processor: transform inputs for easier batched model inference
23
24
  model: core detection architecture
24
25
  """
@@ -32,12 +33,13 @@ class DetectionPredictor(nn.Module):
32
33
  self.pre_processor = pre_processor
33
34
  self.model = model.eval()
34
35
 
35
- @torch.no_grad()
36
+ @torch.inference_mode()
36
37
  def forward(
37
38
  self,
38
39
  pages: List[Union[np.ndarray, torch.Tensor]],
40
+ return_maps: bool = False,
39
41
  **kwargs: Any,
40
- ) -> List[np.ndarray]:
42
+ ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
41
43
  # Dimension check
42
44
  if any(page.ndim != 3 for page in pages):
43
45
  raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
@@ -47,5 +49,13 @@ class DetectionPredictor(nn.Module):
47
49
  self.model, processed_batches = set_device_and_dtype(
48
50
  self.model, processed_batches, _params.device, _params.dtype
49
51
  )
50
- predicted_batches = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
51
- return [pred for batch in predicted_batches for pred in batch]
52
+ predicted_batches = [
53
+ self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
54
+ ]
55
+ preds = [pred for batch in predicted_batches for pred in batch["preds"]]
56
+ if return_maps:
57
+ seg_maps = [
58
+ pred.permute(1, 2, 0).detach().cpu().numpy() for batch in predicted_batches for pred in batch["out_map"]
59
+ ]
60
+ return preds, seg_maps
61
+ return preds
@@ -1,9 +1,9 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Dict, List, Union
6
+ from typing import Any, Dict, List, Tuple, Union
7
7
 
8
8
  import numpy as np
9
9
  import tensorflow as tf
@@ -19,6 +19,7 @@ class DetectionPredictor(NestedObject):
19
19
  """Implements an object able to localize text elements in a document
20
20
 
21
21
  Args:
22
+ ----
22
23
  pre_processor: transform inputs for easier batched model inference
23
24
  model: core detection architecture
24
25
  """
@@ -36,14 +37,21 @@ class DetectionPredictor(NestedObject):
36
37
  def __call__(
37
38
  self,
38
39
  pages: List[Union[np.ndarray, tf.Tensor]],
40
+ return_maps: bool = False,
39
41
  **kwargs: Any,
40
- ) -> List[Dict[str, np.ndarray]]:
42
+ ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
41
43
  # Dimension check
42
44
  if any(page.ndim != 3 for page in pages):
43
45
  raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
44
46
 
45
47
  processed_batches = self.pre_processor(pages)
46
48
  predicted_batches = [
47
- self.model(batch, return_preds=True, training=False, **kwargs)["preds"] for batch in processed_batches
49
+ self.model(batch, return_preds=True, return_model_output=True, training=False, **kwargs)
50
+ for batch in processed_batches
48
51
  ]
49
- return [pred for batch in predicted_batches for pred in batch]
52
+
53
+ preds = [pred for batch in predicted_batches for pred in batch["preds"]]
54
+ if return_maps:
55
+ seg_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
56
+ return preds, seg_maps
57
+ return preds
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -14,12 +14,19 @@ from .predictor import DetectionPredictor
14
14
  __all__ = ["detection_predictor"]
15
15
 
16
16
  ARCHS: List[str]
17
- ROT_ARCHS: List[str]
18
17
 
19
18
 
20
19
  if is_tf_available():
21
- ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
22
- ROT_ARCHS = ["linknet_resnet18_rotation"]
20
+ ARCHS = [
21
+ "db_resnet50",
22
+ "db_mobilenet_v3_large",
23
+ "linknet_resnet18",
24
+ "linknet_resnet34",
25
+ "linknet_resnet50",
26
+ "fast_tiny",
27
+ "fast_small",
28
+ "fast_base",
29
+ ]
23
30
  elif is_torch_available():
24
31
  ARCHS = [
25
32
  "db_resnet34",
@@ -28,30 +35,24 @@ elif is_torch_available():
28
35
  "linknet_resnet18",
29
36
  "linknet_resnet34",
30
37
  "linknet_resnet50",
38
+ "fast_tiny",
39
+ "fast_small",
40
+ "fast_base",
31
41
  ]
32
- ROT_ARCHS = ["db_resnet50_rotation"]
33
42
 
34
43
 
35
44
  def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
36
45
  if isinstance(arch, str):
37
- if arch not in ARCHS + ROT_ARCHS:
46
+ if arch not in ARCHS:
38
47
  raise ValueError(f"unknown architecture '{arch}'")
39
48
 
40
- if arch not in ROT_ARCHS and not assume_straight_pages:
41
- raise AssertionError(
42
- "You are trying to use a model trained on straight pages while not assuming"
43
- " your pages are straight. If you have only straight documents, don't pass"
44
- " assume_straight_pages=False, otherwise you should use one of these archs:"
45
- f"{ROT_ARCHS}"
46
- )
47
-
48
49
  _model = detection.__dict__[arch](
49
50
  pretrained=pretrained,
50
51
  pretrained_backbone=kwargs.get("pretrained_backbone", True),
51
52
  assume_straight_pages=assume_straight_pages,
52
53
  )
53
54
  else:
54
- if not isinstance(arch, (detection.DBNet, detection.LinkNet)):
55
+ if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
55
56
  raise ValueError(f"unknown architecture: {type(arch)}")
56
57
 
57
58
  _model = arch
@@ -84,12 +85,14 @@ def detection_predictor(
84
85
  >>> out = model([input_page])
85
86
 
86
87
  Args:
88
+ ----
87
89
  arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
88
90
  pretrained: If True, returns a model pre-trained on our text detection dataset
89
91
  assume_straight_pages: If True, fit straight boxes to the page
92
+ **kwargs: optional keyword arguments passed to the architecture
90
93
 
91
94
  Returns:
95
+ -------
92
96
  Detection predictor
93
97
  """
94
-
95
98
  return _predictor(arch, pretrained, assume_straight_pages, **kwargs)
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,7 +13,15 @@ import textwrap
13
13
  from pathlib import Path
14
14
  from typing import Any
15
15
 
16
- from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, snapshot_download
16
+ from huggingface_hub import (
17
+ HfApi,
18
+ Repository,
19
+ get_token,
20
+ get_token_permission,
21
+ hf_hub_download,
22
+ login,
23
+ snapshot_download,
24
+ )
17
25
 
18
26
  from doctr import models
19
27
  from doctr.file_utils import is_tf_available, is_torch_available
@@ -26,7 +34,7 @@ __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config
26
34
 
27
35
  AVAILABLE_ARCHS = {
28
36
  "classification": models.classification.zoo.ARCHS,
29
- "detection": models.detection.zoo.ARCHS + models.detection.zoo.ROT_ARCHS,
37
+ "detection": models.detection.zoo.ARCHS,
30
38
  "recognition": models.recognition.zoo.ARCHS,
31
39
  "obj_detection": ["fasterrcnn_mobilenet_v3_large_fpn"] if is_torch_available() else None,
32
40
  }
@@ -34,13 +42,12 @@ AVAILABLE_ARCHS = {
34
42
 
35
43
  def login_to_hub() -> None: # pragma: no cover
36
44
  """Login to huggingface hub"""
37
- access_token = HfFolder.get_token()
38
- if access_token is not None and HfApi()._is_valid_token(access_token):
45
+ access_token = get_token()
46
+ if access_token is not None and get_token_permission(access_token):
39
47
  logging.info("Huggingface Hub token found and valid")
40
- HfApi().set_access_token(access_token)
48
+ login(token=access_token, write_permission=True)
41
49
  else:
42
- subprocess.call(["huggingface-cli", "login"])
43
- HfApi().set_access_token(HfFolder().get_token())
50
+ login()
44
51
  # check if git lfs is installed
45
52
  try:
46
53
  subprocess.call(["git", "lfs", "version"])
@@ -56,6 +63,7 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
56
63
  """Save model and config to disk for pushing to huggingface hub
57
64
 
58
65
  Args:
66
+ ----
59
67
  model: TF or PyTorch model to be saved
60
68
  save_dir: directory to save model and config
61
69
  arch: architecture name
@@ -91,6 +99,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
91
99
  >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
92
100
 
93
101
  Args:
102
+ ----
94
103
  model: TF or PyTorch model to be saved
95
104
  model_name: name of the model which is also the repository name
96
105
  task: task name
@@ -165,7 +174,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
165
174
  commit_message = f"Add {model_name} model"
166
175
 
167
176
  local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
168
- repo_url = HfApi().create_repo(model_name, token=HfFolder.get_token(), exist_ok=False)
177
+ repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
169
178
  repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True)
170
179
 
171
180
  with repo.commit(commit_message):
@@ -183,13 +192,14 @@ def from_hub(repo_id: str, **kwargs: Any):
183
192
  >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")
184
193
 
185
194
  Args:
195
+ ----
186
196
  repo_id: HuggingFace model hub repo
187
197
  kwargs: kwargs of `hf_hub_download` or `snapshot_download`
188
198
 
189
199
  Returns:
200
+ -------
190
201
  Model loaded with the checkpoint
191
202
  """
192
-
193
203
  # Get the config
194
204
  with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
195
205
  cfg = json.load(f)
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -17,6 +17,7 @@ class _KIEPredictor(_OCRPredictor):
17
17
  """Implements an object able to localize and identify text elements in a set of documents
18
18
 
19
19
  Args:
20
+ ----
20
21
  assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
21
22
  without rotated textual elements.
22
23
  straighten_pages: if True, estimates the page general orientation based on the median line orientation.