python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
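
The headline addition in 0.8.1 is the FAST text-detection family (doctr/models/detection/fast/) together with its TextNet backbone (doctr/models/classification/textnet/), alongside a reworked DBNet loss and new datasets (orientation.py, wildreceipt.py). As a quick orientation before the per-file diffs below, here is a minimal usage sketch; the architecture name "fast_base" is an assumption based on the new module names and doctr's usual naming convention, so check doctr.models.detection.zoo for the exact registered names:

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    # "fast_base" is a hypothetical arch name; the fast/ modules below
    # suggest fast_* variants are registered in the detection zoo.
    predictor = ocr_predictor(det_arch="fast_base", reco_arch="crnn_vgg16_bn", pretrained=True)
    doc = DocumentFile.from_images(["sample.jpg"])
    result = predictor(doc)
    print(result.export())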
doctr/models/detection/differentiable_binarization/tensorflow.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -15,7 +15,7 @@ from tensorflow.keras import layers
 from tensorflow.keras.applications import ResNet50
 
 from doctr.file_utils import CLASS_NAME
-from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params
+from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
 from doctr.utils.repr import NestedObject
 
 from ...classification import mobilenet_v3_large
@@ -29,13 +29,13 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.2.0/db_resnet50-adcafc63.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0",
     },
     "db_mobilenet_v3_large": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.3.1/db_mobilenet_v3_large-8c16d5bf.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-da524564.zip&src=0",
     },
 }
 
@@ -45,6 +45,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
     <https://arxiv.org/pdf/1612.03144.pdf>`_.
 
     Args:
+    ----
         channels: number of channels to output
     """
 
@@ -66,14 +67,15 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         """Module which performs a 3x3 convolution followed by up-sampling
 
         Args:
+        ----
             channels: number of output channels
             dilation_factor (int): dilation factor to scale the convolution output before concatenation
 
         Returns:
+        -------
             a keras.layers.Layer object, wrapping these operations in a sequential module
 
         """
-
         _layers = conv_sequence(channels, "relu", True, kernel_size=3)
 
         if dilation_factor > 1:
@@ -107,8 +109,11 @@ class DBNet(_DBNet, keras.Model, NestedObject):
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
+    ----
         feature_extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature map is mapped to
+        bin_thresh: threshold for binarization
+        box_thresh: minimal objectness score to consider a box
         assume_straight_pages: if True, fit straight bounding boxes only
         exportable: onnx exportable returns only logits
         cfg: the configuration dict of the model
@@ -122,6 +127,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         feature_extractor: IntermediateLayerGetter,
         fpn_channels: int = 128,  # to be set to 256 to represent the author's initial idea
         bin_thresh: float = 0.3,
+        box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
         cfg: Optional[Dict[str, Any]] = None,
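
With box_thresh now exposed on the constructor and forwarded (together with bin_thresh) to DBPostProcessor, both post-processing thresholds become tunable per model instance. A minimal sketch, assuming the db_resnet50 factory further down forwards extra keyword arguments to DBNet, as its **kwargs signature suggests:

    from doctr.models import detection

    # bin_thresh binarizes the probability map; box_thresh drops candidate
    # boxes whose objectness score falls below it.
    model = detection.db_resnet50(pretrained=False, bin_thresh=0.3, box_thresh=0.1)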
@@ -141,87 +147,96 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
         output_shape = tuple(self.fpn(_inputs).shape)
 
-        self.probability_head = keras.Sequential(
-            [
-                *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
-                layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
-                layers.BatchNormalization(),
-                layers.Activation("relu"),
-                layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
-            ]
-        )
-        self.threshold_head = keras.Sequential(
-            [
-                *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
-                layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
-                layers.BatchNormalization(),
-                layers.Activation("relu"),
-                layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
-            ]
+        self.probability_head = keras.Sequential([
+            *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
+            layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
+            layers.BatchNormalization(),
+            layers.Activation("relu"),
+            layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
+        ])
+        self.threshold_head = keras.Sequential([
+            *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
+            layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
+            layers.BatchNormalization(),
+            layers.Activation("relu"),
+            layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
+        ])
+
+        self.postprocessor = DBPostProcessor(
+            assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
         )
 
-        self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
-
     def compute_loss(
         self,
         out_map: tf.Tensor,
         thresh_map: tf.Tensor,
         target: List[Dict[str, np.ndarray]],
+        gamma: float = 2.0,
+        alpha: float = 0.5,
+        eps: float = 1e-8,
     ) -> tf.Tensor:
         """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
+        ----
             out_map: output feature map of the model of shape (N, H, W, C)
             thresh_map: threshold map of shape (N, H, W, C)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
+            gamma: modulating factor in the focal loss formula
+            alpha: balancing factor in the focal loss formula
+            eps: epsilon factor in dice loss
 
         Returns:
+        -------
             A loss tensor
         """
+        if gamma < 0:
+            raise ValueError("Value of gamma should be greater than or equal to zero.")
 
         prob_map = tf.math.sigmoid(out_map)
         thresh_map = tf.math.sigmoid(thresh_map)
 
-        seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape, True)
+        seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True)
         seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
         seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
+        seg_mask = tf.cast(seg_mask, tf.float32)
         thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
         thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)
 
-        # Compute balanced BCE loss for proba_map
-        bce_scale = 5.0
-        bce_loss = tf.keras.losses.binary_crossentropy(
-            seg_target[..., None],
-            out_map[..., None],
-            from_logits=True,
-        )[seg_mask]
-
-        neg_target = 1 - seg_target[seg_mask]
-        positive_count = tf.math.reduce_sum(seg_target[seg_mask])
-        negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
-        negative_loss = bce_loss * neg_target
-        negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
-        sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
-        balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)
-
-        # Compute dice loss for approxbin_map
-        bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))
-
-        bce_min = tf.math.reduce_min(bce_loss)
-        weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
-        inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
-        union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
-        dice_loss = 1 - 2.0 * inter / union
+        # Focal loss
+        focal_scale = 10.0
+        bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
+
+        # Convert logits to prob, compute gamma factor
+        p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
+        alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
+        # Unreduced loss
+        focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
+        # Class reduced
+        focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))
+
+        # Compute dice loss for each class or for approx binary_map
+        if len(self.class_names) > 1:
+            dice_map = tf.nn.softmax(out_map, axis=-1)
+        else:
+            # compute binary map instead
+            dice_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
+        # Class-reduced dice loss
+        inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
+        cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
+        dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))
 
         # Compute l1 loss for thresh_map
-        l1_scale = 10.0
         if tf.reduce_any(thresh_mask):
-            l1_loss = tf.math.reduce_mean(tf.math.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
+            thresh_mask = tf.cast(thresh_mask, tf.float32)
+            l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / (
+                tf.reduce_sum(thresh_mask) + eps
+            )
         else:
             l1_loss = tf.constant(0.0)
 
-        return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
+        return l1_loss + focal_scale * focal_loss + dice_loss
 
     def call(
         self,
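
The hunk above replaces the hard-negative-mined balanced BCE of 0.7.0 with a focal loss (scaled by 10.0) and a per-class dice loss. A self-contained numpy sketch of the focal term, mirroring the p_t / alpha_t formulation in compute_loss:

    import numpy as np

    def focal_term(logits, target, gamma=2.0, alpha=0.5):
        # Per-pixel focal loss, re-derived in numpy from the hunk above.
        prob = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
        bce = -(target * np.log(prob) + (1 - target) * np.log(1 - prob))
        p_t = target * prob + (1 - target) * (1 - prob)  # prob of the true class
        alpha_t = target * alpha + (1 - target) * (1 - alpha)  # class balancing
        return alpha_t * (1 - p_t) ** gamma * bce  # easy pixels are down-weighted

    # A confident, correct pixel contributes almost nothing...
    print(focal_term(np.array([4.0]), np.array([1.0])))  # ~[2.9e-06]
    # ...while a confident, wrong one keeps most of its BCE weight.
    print(focal_term(np.array([4.0]), np.array([0.0])))  # ~[1.94]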
@@ -241,7 +256,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
             return out
 
         if return_model_output or target is None or return_preds:
-            prob_map = tf.math.sigmoid(logits)
+            prob_map = _bf16_to_float32(tf.math.sigmoid(logits))
 
         if return_model_output:
             out["out_map"] = prob_map
@@ -342,12 +357,14 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
+    -------
         text detection architecture
     """
-
     return _db_resnet(
         "db_resnet50",
         pretrained,
@@ -368,12 +385,14 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
+    -------
         text detection architecture
     """
-
     return _db_mobilenet(
         "db_mobilenet_v3_large",
         pretrained,
doctr/models/detection/fast/__init__.py (new file)
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+    from .tensorflow import *
+elif is_torch_available():
+    from .pytorch import *  # type: ignore[assignment]
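
This __init__.py follows doctr's usual backend dispatch: the same public names resolve to the TensorFlow or PyTorch implementation depending on which framework is installed. So, assuming the fast modules export the usual factory functions (the symbol name below is not shown in this diff), a single import works for either backend:

    # Resolves to .tensorflow or .pytorch based on the installed framework.
    from doctr.models.detection.fast import fast_base  # hypothetical symbol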
doctr/models/detection/fast/base.py (new file)
@@ -0,0 +1,256 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+from typing import Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+from doctr.models.core import BaseModel
+
+from ..core import DetectionPostProcessor
+
+__all__ = ["_FAST", "FASTPostProcessor"]
+
+
+class FASTPostProcessor(DetectionPostProcessor):
+    """Implements a post processor for FAST model.
+
+    Args:
+    ----
+        bin_thresh: threshold used to binarize p_map at inference time
+        box_thresh: minimal objectness score to consider a box
+        assume_straight_pages: whether the inputs were expected to have horizontal text elements
+    """
+
+    def __init__(
+        self,
+        bin_thresh: float = 0.3,
+        box_thresh: float = 0.1,
+        assume_straight_pages: bool = True,
+    ) -> None:
+        super().__init__(box_thresh, bin_thresh, assume_straight_pages)
+        self.unclip_ratio = 1.0
+
+    def polygon_to_box(
+        self,
+        points: np.ndarray,
+    ) -> np.ndarray:
+        """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
+
+        Args:
+        ----
+            points: the polygon points to expand
+
+        Returns:
+        -------
+            a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
+        """
+        if not self.assume_straight_pages:
+            # Compute the rectangle polygon enclosing the raw polygon
+            rect = cv2.minAreaRect(points)
+            points = cv2.boxPoints(rect)
+            # Add 1 pixel to correct cv2 approx
+            area = (rect[1][0] + 1) * (1 + rect[1][1])
+            length = 2 * (rect[1][0] + rect[1][1]) + 2
+        else:
+            poly = Polygon(points)
+            area = poly.area
+            length = poly.length
+        distance = area * self.unclip_ratio / length  # compute distance to expand polygon
+        offset = pyclipper.PyclipperOffset()
+        offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+        _points = offset.Execute(distance)
+        # Take biggest stack of points
+        idx = 0
+        if len(_points) > 1:
+            max_size = 0
+            for _idx, p in enumerate(_points):
+                if len(p) > max_size:
+                    idx = _idx
+                    max_size = len(p)
+            # We ensure that _points can be correctly cast to a ndarray
+            _points = [_points[idx]]
+        expanded_points: np.ndarray = np.asarray(_points)  # expand polygon
+        if len(expanded_points) < 1:
+            return None  # type: ignore[return-value]
+        return (
+            cv2.boundingRect(expanded_points)  # type: ignore[return-value]
+            if self.assume_straight_pages
+            else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
+        )
+
+    def bitmap_to_boxes(
+        self,
+        pred: np.ndarray,
+        bitmap: np.ndarray,
+    ) -> np.ndarray:
+        """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
+
+        Args:
+        ----
+            pred: Pred map from the model output
+            bitmap: Bitmap map computed from pred (binarized)
+
+        Returns:
+        -------
+            numpy array of boxes for the bitmap: (N, 5) rows of (xmin, ymin, xmax, ymax, score)
+            for straight pages, else (N, 4, 2) polygons
+        """
+        height, width = bitmap.shape[:2]
+        boxes: List[Union[np.ndarray, List[float]]] = []
+        # get contours from connected components on the bitmap
+        contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        for contour in contours:
+            # Check whether smallest enclosing bounding box is not too small
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+                continue
+            # Compute objectness
+            if self.assume_straight_pages:
+                x, y, w, h = cv2.boundingRect(contour)
+                points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]])
+                score = self.box_score(pred, points, assume_straight_pages=True)
+            else:
+                score = self.box_score(pred, contour, assume_straight_pages=False)
+
+            if score < self.box_thresh:  # remove polygons with a weak objectness
+                continue
+
+            if self.assume_straight_pages:
+                _box = self.polygon_to_box(points)
+            else:
+                _box = self.polygon_to_box(np.squeeze(contour))
+
+            if self.assume_straight_pages:
+                # compute relative polygon to get rid of img shape
+                x, y, w, h = _box
+                xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
+                boxes.append([xmin, ymin, xmax, ymax, score])
+            else:
+                # compute relative box to get rid of img shape
+                _box[:, 0] /= width
+                _box[:, 1] /= height
+                boxes.append(_box)
+
+        if not self.assume_straight_pages:
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+        else:
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
+
+
+class _FAST(BaseModel):
+    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+    <https://arxiv.org/pdf/2111.02394.pdf>`_.
+    """
+
+    min_size_box: int = 3
+    assume_straight_pages: bool = True
+    shrink_ratio = 0.1
+
+    def build_target(
+        self,
+        target: List[Dict[str, np.ndarray]],
+        output_shape: Tuple[int, int, int],
+        channels_last: bool = True,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """Build the target, and its mask to be used for loss computation.
+
+        Args:
+        ----
+            target: target coming from dataset
+            output_shape: shape of the output of the model without batch_size
+            channels_last: whether channels are last or not
+
+        Returns:
+        -------
+            the new formatted target, mask and shrunken text kernel
+        """
+        if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
+            raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
+        if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
+            raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")
+
+        h: int
+        w: int
+        if channels_last:
+            h, w, num_classes = output_shape
+        else:
+            num_classes, h, w = output_shape
+        target_shape = (len(target), num_classes, h, w)
+
+        seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
+        seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
+        shrunken_kernel: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
+
+        for idx, tgt in enumerate(target):
+            for class_idx, _tgt in enumerate(tgt.values()):
+                # Draw each polygon on gt
+                if _tgt.shape[0] == 0:
+                    # Empty image, full masked
+                    seg_mask[idx, class_idx] = False
+
+                # Absolute bounding boxes
+                abs_boxes = _tgt.copy()
+
+                if abs_boxes.ndim == 3:
+                    abs_boxes[:, :, 0] *= w
+                    abs_boxes[:, :, 1] *= h
+                    polys = abs_boxes
+                    boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
+                    abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
+                else:
+                    abs_boxes[:, [0, 2]] *= w
+                    abs_boxes[:, [1, 3]] *= h
+                    abs_boxes = abs_boxes.round().astype(np.int32)
+                    polys = np.stack(
+                        [
+                            abs_boxes[:, [0, 1]],
+                            abs_boxes[:, [0, 3]],
+                            abs_boxes[:, [2, 3]],
+                            abs_boxes[:, [2, 1]],
+                        ],
+                        axis=1,
+                    )
+                    boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
+
+                for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
+                    # Mask boxes that are too small
+                    if box_size < self.min_size_box:
+                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+                        continue
+
+                    # Negative shrink for gt, as described in paper
+                    polygon = Polygon(poly)
+                    distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
+                    subject = [tuple(coor) for coor in poly]
+                    padding = pyclipper.PyclipperOffset()
+                    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+                    shrunken = padding.Execute(-distance)
+
+                    # Draw polygon on gt if it is valid
+                    if len(shrunken) == 0:
+                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+                        continue
+                    shrunken = np.array(shrunken[0]).reshape(-1, 2)
+                    if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
+                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+                        continue
+                    cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)  # type: ignore[call-overload]
+                    # draw the original polygon on the segmentation target
+                    cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)  # type: ignore[call-overload]
+
+        # Don't forget to switch back to channel last if Tensorflow is used
+        if channels_last:
+            seg_target = seg_target.transpose((0, 2, 3, 1))
+            seg_mask = seg_mask.transpose((0, 2, 3, 1))
+            shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))
+
+        return seg_target, seg_mask, shrunken_kernel
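
build_target shrinks every ground-truth polygon into the thin "text kernel" the FAST paper trains on, offsetting each edge inward by area * (1 - shrink_ratio**2) / perimeter; FASTPostProcessor.polygon_to_box applies the inverse expansion (area * unclip_ratio / perimeter) at inference. A standalone sketch of both offsets on a toy rectangle, using the same shapely/pyclipper calls as the code above:

    import numpy as np
    import pyclipper
    from shapely.geometry import Polygon

    box = np.array([(0, 0), (100, 0), (100, 20), (0, 20)])
    poly = Polygon(box)

    # Training-time shrink, as in _FAST.build_target (shrink_ratio = 0.1):
    # 2000 * (1 - 0.01) / 240 = 8.25 -> each edge moves ~8 px inward.
    d_shrink = poly.area * (1 - 0.1**2) / poly.length
    pco = pyclipper.PyclipperOffset()
    pco.AddPath([tuple(p) for p in box], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    kernel = np.array(pco.Execute(-d_shrink)[0])

    # Inference-time unclip, as in FASTPostProcessor.polygon_to_box
    # (unclip_ratio = 1.0): re-grows the kernel towards full text size.
    k_poly = Polygon(kernel)
    d_expand = k_poly.area * 1.0 / k_poly.length
    pco = pyclipper.PyclipperOffset()
    pco.AddPath([tuple(p) for p in kernel], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    expanded = np.array(pco.Execute(d_expand)[0])
    print(kernel.min(0), kernel.max(0), expanded.min(0), expanded.max(0))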