python-doctr 0.8.0-py3-none-any.whl → 0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/datasets/pytorch.py +2 -2
  6. doctr/datasets/generator/base.py +6 -5
  7. doctr/datasets/imgur5k.py +1 -1
  8. doctr/datasets/loader.py +1 -6
  9. doctr/datasets/utils.py +2 -1
  10. doctr/datasets/vocabs.py +9 -2
  11. doctr/file_utils.py +26 -12
  12. doctr/io/elements.py +40 -6
  13. doctr/io/html.py +2 -2
  14. doctr/io/image/pytorch.py +6 -8
  15. doctr/io/image/tensorflow.py +1 -1
  16. doctr/io/pdf.py +5 -2
  17. doctr/io/reader.py +6 -0
  18. doctr/models/__init__.py +0 -1
  19. doctr/models/_utils.py +57 -20
  20. doctr/models/builder.py +71 -13
  21. doctr/models/classification/mobilenet/pytorch.py +45 -9
  22. doctr/models/classification/mobilenet/tensorflow.py +38 -7
  23. doctr/models/classification/predictor/pytorch.py +18 -11
  24. doctr/models/classification/predictor/tensorflow.py +16 -10
  25. doctr/models/classification/textnet/pytorch.py +3 -3
  26. doctr/models/classification/textnet/tensorflow.py +3 -3
  27. doctr/models/classification/zoo.py +39 -15
  28. doctr/models/detection/__init__.py +1 -0
  29. doctr/models/detection/_utils/__init__.py +1 -0
  30. doctr/models/detection/_utils/base.py +66 -0
  31. doctr/models/detection/differentiable_binarization/base.py +4 -3
  32. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  33. doctr/models/detection/differentiable_binarization/tensorflow.py +14 -18
  34. doctr/models/detection/fast/__init__.py +6 -0
  35. doctr/models/detection/fast/base.py +257 -0
  36. doctr/models/detection/fast/pytorch.py +442 -0
  37. doctr/models/detection/fast/tensorflow.py +428 -0
  38. doctr/models/detection/linknet/base.py +4 -3
  39. doctr/models/detection/predictor/pytorch.py +15 -1
  40. doctr/models/detection/predictor/tensorflow.py +15 -1
  41. doctr/models/detection/zoo.py +21 -4
  42. doctr/models/factory/hub.py +3 -12
  43. doctr/models/kie_predictor/base.py +9 -3
  44. doctr/models/kie_predictor/pytorch.py +41 -20
  45. doctr/models/kie_predictor/tensorflow.py +36 -16
  46. doctr/models/modules/layers/pytorch.py +89 -10
  47. doctr/models/modules/layers/tensorflow.py +88 -10
  48. doctr/models/modules/transformer/pytorch.py +2 -2
  49. doctr/models/predictor/base.py +77 -50
  50. doctr/models/predictor/pytorch.py +31 -20
  51. doctr/models/predictor/tensorflow.py +27 -17
  52. doctr/models/preprocessor/pytorch.py +4 -4
  53. doctr/models/preprocessor/tensorflow.py +3 -2
  54. doctr/models/recognition/master/pytorch.py +2 -2
  55. doctr/models/recognition/parseq/pytorch.py +4 -3
  56. doctr/models/recognition/parseq/tensorflow.py +4 -3
  57. doctr/models/recognition/sar/pytorch.py +7 -6
  58. doctr/models/recognition/sar/tensorflow.py +3 -9
  59. doctr/models/recognition/vitstr/pytorch.py +1 -1
  60. doctr/models/recognition/zoo.py +1 -1
  61. doctr/models/zoo.py +2 -2
  62. doctr/py.typed +0 -0
  63. doctr/transforms/functional/base.py +1 -1
  64. doctr/transforms/functional/pytorch.py +4 -4
  65. doctr/transforms/modules/base.py +37 -15
  66. doctr/transforms/modules/pytorch.py +66 -8
  67. doctr/transforms/modules/tensorflow.py +63 -7
  68. doctr/utils/fonts.py +7 -5
  69. doctr/utils/geometry.py +35 -12
  70. doctr/utils/metrics.py +33 -174
  71. doctr/utils/reconstitution.py +126 -0
  72. doctr/utils/visualization.py +5 -118
  73. doctr/version.py +1 -1
  74. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/METADATA +96 -91
  75. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/RECORD +79 -75
  76. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/WHEEL +1 -1
  77. doctr/models/artefacts/__init__.py +0 -2
  78. doctr/models/artefacts/barcode.py +0 -74
  79. doctr/models/artefacts/face.py +0 -63
  80. doctr/models/obj_detection/__init__.py +0 -1
  81. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  82. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  83. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/LICENSE +0 -0
  84. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/top_level.txt +0 -0
  85. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/zip-safe +0 -0
doctr/models/classification/textnet/tensorflow.py

@@ -22,21 +22,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-9e605bd8.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0",
     },
     "textnet_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-4784b292.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0",
     },
     "textnet_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-2c3f3265.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0",
     },
 }
 
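These three entries only repoint the pretrained checkpoints from the v0.7.0 archives to retrained v0.8.1 ones; the preprocessing statistics, input shape and vocabulary are untouched. Loading code is unaffected; a minimal sketch (assuming the TensorFlow backend, since this hunk is the channels-last config):

from doctr.models import classification

# pretrained=True now resolves to the v0.8.1 textnet checkpoint
model = classification.textnet_small(pretrained=True)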

doctr/models/classification/zoo.py

@@ -9,9 +9,9 @@ from doctr.file_utils import is_tf_available
 
 from .. import classification
 from ..preprocessor import PreProcessor
-from .predictor import CropOrientationPredictor
+from .predictor import OrientationPredictor
 
-__all__ = ["crop_orientation_predictor"]
+__all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
 
 ARCHS: List[str] = [
     "magc_resnet31",
@@ -31,10 +31,10 @@ ARCHS: List[str] = [
     "vit_s",
     "vit_b",
 ]
-ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
+ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
-def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> CropOrientationPredictor:
+def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor:
     if arch not in ORIENTATION_ARCHS:
         raise ValueError(f"unknown architecture '{arch}'")
 
@@ -42,33 +42,57 @@ def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> C
     _model = classification.__dict__[arch](pretrained=pretrained)
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size", 64)
+    kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
     input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
-    predictor = CropOrientationPredictor(
+    predictor = OrientationPredictor(
         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
     )
     return predictor
 
 
 def crop_orientation_predictor(
-    arch: str = "mobilenet_v3_small_orientation", pretrained: bool = False, **kwargs: Any
-) -> CropOrientationPredictor:
-    """Orientation classification architecture.
+    arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
+) -> OrientationPredictor:
+    """Crop orientation classification architecture.
 
     >>> import numpy as np
     >>> from doctr.models import crop_orientation_predictor
-    >>> model = crop_orientation_predictor(arch='classif_mobilenet_v3_small', pretrained=True)
-    >>> input_crop = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
+    >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation', pretrained=True)
+    >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8)
     >>> out = model([input_crop])
 
     Args:
     ----
-        arch: name of the architecture to use (e.g. 'mobilenet_v3_small')
+        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
-        **kwargs: keyword arguments to be passed to the CropOrientationPredictor
+        **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
     -------
-        CropOrientationPredictor
+        OrientationPredictor
     """
-    return _crop_orientation_predictor(arch, pretrained, **kwargs)
+    return _orientation_predictor(arch, pretrained, **kwargs)
+
+
+def page_orientation_predictor(
+    arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
+) -> OrientationPredictor:
+    """Page orientation classification architecture.
+
+    >>> import numpy as np
+    >>> from doctr.models import page_orientation_predictor
+    >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation', pretrained=True)
+    >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
+    >>> out = model([input_page])
+
+    Args:
+    ----
+        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
+        pretrained: If True, returns a model pre-trained on our page orientation dataset
+        **kwargs: keyword arguments to be passed to the OrientationPredictor
+
+    Returns:
+    -------
+        OrientationPredictor
+    """
+    return _orientation_predictor(arch, pretrained, **kwargs)
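The rename from CropOrientationPredictor to OrientationPredictor lets one predictor class serve both granularities, with crop_orientation_predictor and page_orientation_predictor as symmetric entry points. A minimal usage sketch combining the two, based on the doctest examples above (pretrained weights assumed available for both architectures):

import numpy as np

from doctr.models import crop_orientation_predictor, page_orientation_predictor

# both factories return an OrientationPredictor
crop_model = crop_orientation_predictor(pretrained=True)  # mobilenet_v3_small_crop_orientation
page_model = page_orientation_predictor(pretrained=True)  # mobilenet_v3_small_page_orientation

crops = [(255 * np.random.rand(256, 256, 3)).astype(np.uint8)]
pages = [(255 * np.random.rand(512, 512, 3)).astype(np.uint8)]

crop_out = crop_model(crops)  # batched internally (default batch_size: 128 for crops, 4 for pages)
page_out = page_model(pages)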

doctr/models/detection/__init__.py

@@ -1,3 +1,4 @@
 from .differentiable_binarization import *
 from .linknet import *
+from .fast import *
 from .zoo import *

doctr/models/detection/_utils/__init__.py

@@ -1,4 +1,5 @@
 from doctr.file_utils import is_tf_available
+from .base import *
 
 if is_tf_available():
     from .tensorflow import *

doctr/models/detection/_utils/base.py (new file)

@@ -0,0 +1,66 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Dict, List
+
+import numpy as np
+
+__all__ = ["_remove_padding"]
+
+
+def _remove_padding(
+    pages: List[np.ndarray],
+    loc_preds: List[Dict[str, np.ndarray]],
+    preserve_aspect_ratio: bool,
+    symmetric_pad: bool,
+    assume_straight_pages: bool,
+) -> List[Dict[str, np.ndarray]]:
+    """Remove padding from the localization predictions
+
+    Args:
+    ----
+        pages: list of pages
+        loc_preds: list of localization predictions
+        preserve_aspect_ratio: whether the aspect ratio was preserved during padding
+        symmetric_pad: whether the padding was symmetric
+        assume_straight_pages: whether the pages are assumed to be straight
+
+    Returns:
+    -------
+        list of unpadded localization predictions
+    """
+    if preserve_aspect_ratio:
+        # Rectify loc_preds to remove padding
+        rectified_preds = []
+        for page, dict_loc_preds in zip(pages, loc_preds):
+            for k, loc_pred in dict_loc_preds.items():
+                h, w = page.shape[0], page.shape[1]
+                if h > w:
+                    # y unchanged, dilate x coord
+                    if symmetric_pad:
+                        if assume_straight_pages:
+                            loc_pred[:, [0, 2]] = (loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5
+                        else:
+                            loc_pred[:, :, 0] = (loc_pred[:, :, 0] - 0.5) * h / w + 0.5
+                    else:
+                        if assume_straight_pages:
+                            loc_pred[:, [0, 2]] *= h / w
+                        else:
+                            loc_pred[:, :, 0] *= h / w
+                elif w > h:
+                    # x unchanged, dilate y coord
+                    if symmetric_pad:
+                        if assume_straight_pages:
+                            loc_pred[:, [1, 3]] = (loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5
+                        else:
+                            loc_pred[:, :, 1] = (loc_pred[:, :, 1] - 0.5) * w / h + 0.5
+                    else:
+                        if assume_straight_pages:
+                            loc_pred[:, [1, 3]] *= w / h
+                        else:
+                            loc_pred[:, :, 1] *= w / h
+                rectified_preds.append({k: np.clip(loc_pred, 0, 1)})
+        return rectified_preds
+    return loc_preds
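To see what the rectification does, take a portrait page (h > w) that was symmetrically padded to a square before detection: a straight box keeps its y coordinates, while its x extent is dilated by h / w around the center 0.5. A self-contained check of that arithmetic with made-up coordinates:

import numpy as np

h, w = 1000, 500  # portrait page, so padding was added along x
# one straight box in relative (xmin, ymin, xmax, ymax) coordinates on the padded square
box = np.array([[0.4, 0.2, 0.6, 0.3]])

# symmetric-padding branch of _remove_padding: dilate x around the center
box[:, [0, 2]] = (box[:, [0, 2]] - 0.5) * h / w + 0.5
print(np.clip(box, 0, 1))  # [[0.3 0.2 0.7 0.3]]: the x range doubled around 0.5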

doctr/models/detection/differentiable_binarization/base.py

@@ -114,7 +114,7 @@ class DBPostProcessor(DetectionPostProcessor):
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):  # type: ignore[index]
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -150,10 +150,11 @@ class DBPostProcessor(DetectionPostProcessor):
                     raise AssertionError("When assume straight pages is false a box is a (4, 2) array (polygon)")
                 _box[:, 0] /= width
                 _box[:, 1] /= height
-                boxes.append(_box)
+                # Add score to box as (0, score)
+                boxes.append(np.vstack([_box, np.array([0.0, score])]))
 
         if not self.assume_straight_pages:
-            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
         else:
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
 
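With this change, every rotated-page prediction is a single (5, 2) array: the first four rows are the polygon corners, and the score rides along as a final (0, score) row (which is also why the empty fallback grows from (0, 4, 2) to (0, 5, 2)). A sketch of how a consumer could split the two apart again, on fabricated output:

import numpy as np

# fabricated model output: two rotated boxes in the new (N, 5, 2) layout
loc_pred = np.array([
    [[0.1, 0.1], [0.3, 0.1], [0.3, 0.2], [0.1, 0.2], [0.0, 0.93]],
    [[0.5, 0.5], [0.7, 0.5], [0.7, 0.6], [0.5, 0.6], [0.0, 0.88]],
])

polygons = loc_pred[:, :4]  # (N, 4, 2) corner coordinates
scores = loc_pred[:, 4, 1]  # (N,) objectness, stored as the (0, score) row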

doctr/models/detection/differentiable_binarization/pytorch.py

@@ -39,7 +39,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-81e9b152.pt&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/db_mobilenet_v3_large-21748dd0.pt&src=0",
     },
 }
 
@@ -273,7 +273,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
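The un-ignored expression is the soft binarization from the DB paper: an approximate binary map B = 1 / (1 + e^(-k (P - T))) with amplification factor k = 50, i.e. a steep sigmoid of the gap between the probability and threshold maps. A quick standalone sanity check of that identity (plain PyTorch, not doctr API):

import torch

prob_map = torch.rand(1, 1, 8, 8)
thresh_map = torch.rand(1, 1, 8, 8)

manual = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
via_sigmoid = torch.sigmoid(50.0 * (prob_map - thresh_map))
assert torch.allclose(manual, via_sigmoid)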

doctr/models/detection/differentiable_binarization/tensorflow.py

@@ -147,24 +147,20 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
         output_shape = tuple(self.fpn(_inputs).shape)
 
-        self.probability_head = keras.Sequential(
-            [
-                *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
-                layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
-                layers.BatchNormalization(),
-                layers.Activation("relu"),
-                layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
-            ]
-        )
-        self.threshold_head = keras.Sequential(
-            [
-                *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
-                layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
-                layers.BatchNormalization(),
-                layers.Activation("relu"),
-                layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
-            ]
-        )
+        self.probability_head = keras.Sequential([
+            *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
+            layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
+            layers.BatchNormalization(),
+            layers.Activation("relu"),
+            layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
+        ])
+        self.threshold_head = keras.Sequential([
+            *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
+            layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
+            layers.BatchNormalization(),
+            layers.Activation("relu"),
+            layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
+        ])
 
         self.postprocessor = DBPostProcessor(
             assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh

doctr/models/detection/fast/__init__.py (new file)

@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+    from .tensorflow import *
+elif is_torch_available():
+    from .pytorch import *  # type: ignore[assignment]

doctr/models/detection/fast/base.py (new file)

@@ -0,0 +1,257 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+from typing import Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+from doctr.models.core import BaseModel
+
+from ..core import DetectionPostProcessor
+
+__all__ = ["_FAST", "FASTPostProcessor"]
+
+
+class FASTPostProcessor(DetectionPostProcessor):
+    """Implements a post processor for the FAST model.
+
+    Args:
+    ----
+        bin_thresh: threshold used to binarize p_map at inference time
+        box_thresh: minimal objectness score to consider a box
+        assume_straight_pages: whether the inputs were expected to have horizontal text elements
+    """
+
+    def __init__(
+        self,
+        bin_thresh: float = 0.1,
+        box_thresh: float = 0.1,
+        assume_straight_pages: bool = True,
+    ) -> None:
+        super().__init__(box_thresh, bin_thresh, assume_straight_pages)
+        self.unclip_ratio = 1.0
+
+    def polygon_to_box(
+        self,
+        points: np.ndarray,
+    ) -> np.ndarray:
+        """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
+
+        Args:
+        ----
+            points: coordinates of the polygon to expand
+
+        Returns:
+        -------
+            a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
+        """
+        if not self.assume_straight_pages:
+            # Compute the rectangle polygon enclosing the raw polygon
+            rect = cv2.minAreaRect(points)
+            points = cv2.boxPoints(rect)
+            # Add 1 pixel to correct cv2 approx
+            area = (rect[1][0] + 1) * (1 + rect[1][1])
+            length = 2 * (rect[1][0] + rect[1][1]) + 2
+        else:
+            poly = Polygon(points)
+            area = poly.area
+            length = poly.length
+        distance = area * self.unclip_ratio / length  # compute distance to expand polygon
+        offset = pyclipper.PyclipperOffset()
+        offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+        _points = offset.Execute(distance)
+        # Take the biggest stack of points
+        idx = 0
+        if len(_points) > 1:
+            max_size = 0
+            for _idx, p in enumerate(_points):
+                if len(p) > max_size:
+                    idx = _idx
+                    max_size = len(p)
+            # We ensure that _points can be correctly cast to an ndarray
+            _points = [_points[idx]]
+        expanded_points: np.ndarray = np.asarray(_points)  # expand polygon
+        if len(expanded_points) < 1:
+            return None  # type: ignore[return-value]
+        return (
+            cv2.boundingRect(expanded_points)  # type: ignore[return-value]
+            if self.assume_straight_pages
+            else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
+        )
+
+    def bitmap_to_boxes(
+        self,
+        pred: np.ndarray,
+        bitmap: np.ndarray,
+    ) -> np.ndarray:
+        """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
+
+        Args:
+        ----
+            pred: Pred map from the model output
+            bitmap: Bitmap map computed from pred (binarized)
+
+        Returns:
+        -------
+            np array of boxes for the bitmap: (*, 5) rows (xmin, ymin, xmax, ymax, score) for
+            straight pages, or (*, 5, 2) polygons whose last row encodes the score as (0, score)
+        """
+        height, width = bitmap.shape[:2]
+        boxes: List[Union[np.ndarray, List[float]]] = []
+        # get contours from connected components on the bitmap
+        contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        for contour in contours:
+            # Check whether smallest enclosing bounding box is not too small
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
+                continue
+            # Compute objectness
+            if self.assume_straight_pages:
+                x, y, w, h = cv2.boundingRect(contour)
+                points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]])
+                score = self.box_score(pred, points, assume_straight_pages=True)
+            else:
+                score = self.box_score(pred, contour, assume_straight_pages=False)
+
+            if score < self.box_thresh:  # remove polygons with a weak objectness
+                continue
+
+            if self.assume_straight_pages:
+                _box = self.polygon_to_box(points)
+            else:
+                _box = self.polygon_to_box(np.squeeze(contour))
+
+            if self.assume_straight_pages:
+                # compute relative polygon to get rid of img shape
+                x, y, w, h = _box
+                xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
+                boxes.append([xmin, ymin, xmax, ymax, score])
+            else:
+                # compute relative box to get rid of img shape
+                _box[:, 0] /= width
+                _box[:, 1] /= height
+                # Add score to box as (0, score)
+                boxes.append(np.vstack([_box, np.array([0.0, score])]))
+
+        if not self.assume_straight_pages:
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
+        else:
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
+
+
+class _FAST(BaseModel):
+    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+    <https://arxiv.org/pdf/2111.02394.pdf>`_.
+    """
+
+    min_size_box: int = 3
+    assume_straight_pages: bool = True
+    shrink_ratio = 0.4
+
+    def build_target(
+        self,
+        target: List[Dict[str, np.ndarray]],
+        output_shape: Tuple[int, int, int],
+        channels_last: bool = True,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """Build the target, and its mask to be used for loss computation.
+
+        Args:
+        ----
+            target: target coming from dataset
+            output_shape: shape of the output of the model without batch_size
+            channels_last: whether channels are last or not
+
+        Returns:
+        -------
+            the new formatted target, mask and shrunken text kernel
+        """
+        if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
+            raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
+        if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
+            raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")
+
+        h: int
+        w: int
+        if channels_last:
+            h, w, num_classes = output_shape
+        else:
+            num_classes, h, w = output_shape
+        target_shape = (len(target), num_classes, h, w)
+
+        seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
+        seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
+        shrunken_kernel: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
+
+        for idx, tgt in enumerate(target):
+            for class_idx, _tgt in enumerate(tgt.values()):
+                # Draw each polygon on gt
+                if _tgt.shape[0] == 0:
+                    # Empty image, full masked
+                    seg_mask[idx, class_idx] = False
+
+                # Absolute bounding boxes
+                abs_boxes = _tgt.copy()
+
+                if abs_boxes.ndim == 3:
+                    abs_boxes[:, :, 0] *= w
+                    abs_boxes[:, :, 1] *= h
+                    polys = abs_boxes
+                    boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
+                    abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
+                else:
+                    abs_boxes[:, [0, 2]] *= w
+                    abs_boxes[:, [1, 3]] *= h
+                    abs_boxes = abs_boxes.round().astype(np.int32)
+                    polys = np.stack(
+                        [
+                            abs_boxes[:, [0, 1]],
+                            abs_boxes[:, [0, 3]],
+                            abs_boxes[:, [2, 3]],
+                            abs_boxes[:, [2, 1]],
+                        ],
+                        axis=1,
+                    )
+                    boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
+
+                for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
+                    # Mask boxes that are too small
+                    if box_size < self.min_size_box:
+                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+                        continue
+
+                    # Negative shrink for gt, as described in paper
+                    polygon = Polygon(poly)
+                    distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
+                    subject = [tuple(coor) for coor in poly]
+                    padding = pyclipper.PyclipperOffset()
+                    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+                    shrunken = padding.Execute(-distance)
+
+                    # Draw polygon on gt if it is valid
+                    if len(shrunken) == 0:
+                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+                        continue
+                    shrunken = np.array(shrunken[0]).reshape(-1, 2)
+                    if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
+                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+                        continue
+                    cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)  # type: ignore[call-overload]
+                    # draw the original polygon on the segmentation target
+                    cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)  # type: ignore[call-overload]
+
+        # Don't forget to switch back to channels last if TensorFlow is used
+        if channels_last:
+            seg_target = seg_target.transpose((0, 2, 3, 1))
+            seg_mask = seg_mask.transpose((0, 2, 3, 1))
+            shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))
+
+        return seg_target, seg_mask, shrunken_kernel
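The kernel shrinking in build_target uses the same polygon-offsetting rule popularized by DB: each ground-truth polygon is inset by a distance d = A (1 - r^2) / L, where A is the polygon area, L its perimeter and r the shrink_ratio of 0.4. A standalone illustration of that inset on a toy rectangle (shapely and pyclipper only, no doctr imports):

import numpy as np
import pyclipper
from shapely.geometry import Polygon

shrink_ratio = 0.4
poly = np.array([[0, 0], [100, 0], [100, 40], [0, 40]])  # a 100 x 40 rectangle

polygon = Polygon(poly)
# d = A * (1 - r^2) / L = 4000 * 0.84 / 280 = 12.0
distance = polygon.area * (1 - shrink_ratio**2) / polygon.length
padding = pyclipper.PyclipperOffset()
padding.AddPath([tuple(map(int, p)) for p in poly], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunken = np.array(padding.Execute(-distance)[0]).reshape(-1, 2)  # roughly a 76 x 16 kernel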