python-doctr 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (82)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/datasets/pytorch.py +2 -2
  6. doctr/datasets/generator/base.py +6 -5
  7. doctr/datasets/imgur5k.py +1 -1
  8. doctr/datasets/loader.py +1 -6
  9. doctr/datasets/utils.py +2 -1
  10. doctr/datasets/vocabs.py +9 -2
  11. doctr/file_utils.py +26 -12
  12. doctr/io/elements.py +40 -6
  13. doctr/io/html.py +2 -2
  14. doctr/io/image/pytorch.py +6 -8
  15. doctr/io/image/tensorflow.py +1 -1
  16. doctr/io/pdf.py +5 -2
  17. doctr/io/reader.py +6 -0
  18. doctr/models/__init__.py +0 -1
  19. doctr/models/_utils.py +57 -20
  20. doctr/models/builder.py +71 -13
  21. doctr/models/classification/mobilenet/pytorch.py +45 -9
  22. doctr/models/classification/mobilenet/tensorflow.py +38 -7
  23. doctr/models/classification/predictor/pytorch.py +18 -11
  24. doctr/models/classification/predictor/tensorflow.py +16 -10
  25. doctr/models/classification/textnet/pytorch.py +3 -3
  26. doctr/models/classification/textnet/tensorflow.py +3 -3
  27. doctr/models/classification/zoo.py +39 -15
  28. doctr/models/detection/_utils/__init__.py +1 -0
  29. doctr/models/detection/_utils/base.py +66 -0
  30. doctr/models/detection/differentiable_binarization/base.py +4 -3
  31. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  32. doctr/models/detection/fast/base.py +6 -5
  33. doctr/models/detection/fast/pytorch.py +4 -4
  34. doctr/models/detection/fast/tensorflow.py +4 -4
  35. doctr/models/detection/linknet/base.py +4 -3
  36. doctr/models/detection/predictor/pytorch.py +15 -1
  37. doctr/models/detection/predictor/tensorflow.py +15 -1
  38. doctr/models/detection/zoo.py +7 -2
  39. doctr/models/factory/hub.py +3 -12
  40. doctr/models/kie_predictor/base.py +9 -3
  41. doctr/models/kie_predictor/pytorch.py +41 -20
  42. doctr/models/kie_predictor/tensorflow.py +36 -16
  43. doctr/models/modules/layers/pytorch.py +2 -3
  44. doctr/models/modules/layers/tensorflow.py +6 -8
  45. doctr/models/modules/transformer/pytorch.py +2 -2
  46. doctr/models/predictor/base.py +77 -50
  47. doctr/models/predictor/pytorch.py +31 -20
  48. doctr/models/predictor/tensorflow.py +27 -17
  49. doctr/models/preprocessor/pytorch.py +4 -4
  50. doctr/models/preprocessor/tensorflow.py +3 -2
  51. doctr/models/recognition/master/pytorch.py +2 -2
  52. doctr/models/recognition/parseq/pytorch.py +4 -3
  53. doctr/models/recognition/parseq/tensorflow.py +4 -3
  54. doctr/models/recognition/sar/pytorch.py +7 -6
  55. doctr/models/recognition/sar/tensorflow.py +3 -9
  56. doctr/models/recognition/vitstr/pytorch.py +1 -1
  57. doctr/models/recognition/zoo.py +1 -1
  58. doctr/models/zoo.py +2 -2
  59. doctr/py.typed +0 -0
  60. doctr/transforms/functional/base.py +1 -1
  61. doctr/transforms/functional/pytorch.py +4 -4
  62. doctr/transforms/modules/base.py +37 -15
  63. doctr/transforms/modules/pytorch.py +66 -8
  64. doctr/transforms/modules/tensorflow.py +63 -7
  65. doctr/utils/fonts.py +7 -5
  66. doctr/utils/geometry.py +35 -12
  67. doctr/utils/metrics.py +33 -174
  68. doctr/utils/reconstitution.py +126 -0
  69. doctr/utils/visualization.py +5 -118
  70. doctr/version.py +1 -1
  71. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/METADATA +84 -80
  72. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/RECORD +76 -76
  73. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/WHEEL +1 -1
  74. doctr/models/artefacts/__init__.py +0 -2
  75. doctr/models/artefacts/barcode.py +0 -74
  76. doctr/models/artefacts/face.py +0 -63
  77. doctr/models/obj_detection/__init__.py +0 -1
  78. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  79. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  80. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/LICENSE +0 -0
  81. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/top_level.txt +0 -0
  82. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/zip-safe +0 -0
doctr/models/classification/textnet/tensorflow.py
@@ -22,21 +22,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-9e605bd8.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0",
     },
     "textnet_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-4784b292.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0",
     },
     "textnet_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-2c3f3265.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0",
     },
 }
 
doctr/models/classification/zoo.py
@@ -9,9 +9,9 @@ from doctr.file_utils import is_tf_available
 
 from .. import classification
 from ..preprocessor import PreProcessor
-from .predictor import CropOrientationPredictor
+from .predictor import OrientationPredictor
 
-__all__ = ["crop_orientation_predictor"]
+__all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
 
 ARCHS: List[str] = [
     "magc_resnet31",
@@ -31,10 +31,10 @@ ARCHS: List[str] = [
     "vit_s",
     "vit_b",
 ]
-ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
+ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
-def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> CropOrientationPredictor:
+def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor:
     if arch not in ORIENTATION_ARCHS:
         raise ValueError(f"unknown architecture '{arch}'")
 
@@ -42,33 +42,57 @@ def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> C
     _model = classification.__dict__[arch](pretrained=pretrained)
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size", 64)
+    kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
     input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
-    predictor = CropOrientationPredictor(
+    predictor = OrientationPredictor(
         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
     )
     return predictor
 
 
 def crop_orientation_predictor(
-    arch: str = "mobilenet_v3_small_orientation", pretrained: bool = False, **kwargs: Any
-) -> CropOrientationPredictor:
-    """Orientation classification architecture.
+    arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
+) -> OrientationPredictor:
+    """Crop orientation classification architecture.
 
     >>> import numpy as np
     >>> from doctr.models import crop_orientation_predictor
-    >>> model = crop_orientation_predictor(arch='classif_mobilenet_v3_small', pretrained=True)
-    >>> input_crop = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
+    >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation', pretrained=True)
+    >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8)
     >>> out = model([input_crop])
 
     Args:
     ----
-        arch: name of the architecture to use (e.g. 'mobilenet_v3_small')
+        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
-        **kwargs: keyword arguments to be passed to the CropOrientationPredictor
+        **kwargs: keyword arguments to be passed to the OrientationPredictor
 
     Returns:
     -------
-        CropOrientationPredictor
+        OrientationPredictor
     """
-    return _crop_orientation_predictor(arch, pretrained, **kwargs)
+    return _orientation_predictor(arch, pretrained, **kwargs)
+
+
+def page_orientation_predictor(
+    arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
+) -> OrientationPredictor:
+    """Page orientation classification architecture.
+
+    >>> import numpy as np
+    >>> from doctr.models import page_orientation_predictor
+    >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation', pretrained=True)
+    >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
+    >>> out = model([input_page])
+
+    Args:
+    ----
+        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
+        pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        **kwargs: keyword arguments to be passed to the OrientationPredictor
+
+    Returns:
+    -------
+        OrientationPredictor
+    """
+    return _orientation_predictor(arch, pretrained, **kwargs)
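
For context, both factory functions above return the same OrientationPredictor and accept the same kind of input; a minimal sketch of using them side by side, following the doctest snippets in the docstrings (with arbitrary random images):

    import numpy as np
    from doctr.models import crop_orientation_predictor, page_orientation_predictor

    # Both helpers build an OrientationPredictor; only the default architecture
    # and the default batch size (128 for crops, 4 for pages) differ
    crop_classifier = crop_orientation_predictor(pretrained=True)
    page_classifier = page_orientation_predictor(pretrained=True)

    page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
    crops = [(255 * np.random.rand(256, 256, 3)).astype(np.uint8)]

    page_orientations = page_classifier([page])
    crop_orientations = crop_classifier(crops)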
doctr/models/detection/_utils/__init__.py
@@ -1,4 +1,5 @@
 from doctr.file_utils import is_tf_available
+from .base import *
 
 if is_tf_available():
     from .tensorflow import *
doctr/models/detection/_utils/base.py (new file)
@@ -0,0 +1,66 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Dict, List
+
+import numpy as np
+
+__all__ = ["_remove_padding"]
+
+
+def _remove_padding(
+    pages: List[np.ndarray],
+    loc_preds: List[Dict[str, np.ndarray]],
+    preserve_aspect_ratio: bool,
+    symmetric_pad: bool,
+    assume_straight_pages: bool,
+) -> List[Dict[str, np.ndarray]]:
+    """Remove padding from the localization predictions
+
+    Args:
+    ----
+        pages: list of pages
+        loc_preds: list of localization predictions
+        preserve_aspect_ratio: whether the aspect ratio was preserved during padding
+        symmetric_pad: whether the padding was symmetric
+        assume_straight_pages: whether the pages are assumed to be straight
+
+    Returns:
+    -------
+        list of unpadded localization predictions
+    """
+    if preserve_aspect_ratio:
+        # Rectify loc_preds to remove padding
+        rectified_preds = []
+        for page, dict_loc_preds in zip(pages, loc_preds):
+            for k, loc_pred in dict_loc_preds.items():
+                h, w = page.shape[0], page.shape[1]
+                if h > w:
+                    # y unchanged, dilate x coord
+                    if symmetric_pad:
+                        if assume_straight_pages:
+                            loc_pred[:, [0, 2]] = (loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5
+                        else:
+                            loc_pred[:, :, 0] = (loc_pred[:, :, 0] - 0.5) * h / w + 0.5
+                    else:
+                        if assume_straight_pages:
+                            loc_pred[:, [0, 2]] *= h / w
+                        else:
+                            loc_pred[:, :, 0] *= h / w
+                elif w > h:
+                    # x unchanged, dilate y coord
+                    if symmetric_pad:
+                        if assume_straight_pages:
+                            loc_pred[:, [1, 3]] = (loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5
+                        else:
+                            loc_pred[:, :, 1] = (loc_pred[:, :, 1] - 0.5) * w / h + 0.5
+                    else:
+                        if assume_straight_pages:
+                            loc_pred[:, [1, 3]] *= w / h
+                        else:
+                            loc_pred[:, :, 1] *= w / h
+            rectified_preds.append({k: np.clip(loc_pred, 0, 1)})
+        return rectified_preds
+    return loc_preds
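
To illustrate the rectification this helper performs (it is wired into both detection predictors further down), here is a minimal self-contained numpy sketch with hypothetical values, for the symmetric-padding, straight-boxes branch: a portrait page (h > w) is letterboxed along x, so predicted x coordinates are dilated back around the page center by h / w:

    import numpy as np

    # Hypothetical portrait page (h > w): the preprocessor letterboxes it into a
    # square input, so the padded x range must be dilated back around the center
    page = np.zeros((1024, 512, 3), dtype=np.uint8)
    h, w = page.shape[:2]

    # One straight box in relative (xmin, ymin, xmax, ymax, score) layout,
    # expressed in the padded coordinate system
    box = np.array([[0.35, 0.10, 0.65, 0.20, 0.9]])

    # Symmetric padding: rescale x around 0.5 by h / w (here, a factor of 2)
    box[:, [0, 2]] = (box[:, [0, 2]] - 0.5) * h / w + 0.5
    print(np.clip(box, 0, 1))  # x span widens from [0.35, 0.65] to [0.2, 0.8]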
doctr/models/detection/differentiable_binarization/base.py
@@ -114,7 +114,7 @@ class DBPostProcessor(DetectionPostProcessor):
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):  # type: ignore[index]
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -150,10 +150,11 @@ class DBPostProcessor(DetectionPostProcessor):
                 raise AssertionError("When assume straight pages is false a box is a (4, 2) array (polygon)")
             _box[:, 0] /= width
             _box[:, 1] /= height
-            boxes.append(_box)
+            # Add score to box as (0, score)
+            boxes.append(np.vstack([_box, np.array([0.0, score])]))
 
         if not self.assume_straight_pages:
-            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
         else:
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
 
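The same `(0, score)` packing appears in the FAST and LinkNet post-processor hunks below. A short sketch (illustrative values) of what the rotated-box representation now looks like and how the two parts can be split apart downstream:

    import numpy as np

    # A rotated prediction is a (4, 2) polygon in relative coords; the
    # post-processor appends one extra row (0, score) to carry the confidence
    polygon = np.array([[0.1, 0.1], [0.4, 0.1], [0.4, 0.3], [0.1, 0.3]])
    score = 0.87
    packed = np.vstack([polygon, np.array([0.0, score])])  # shape (5, 2)

    # Recover the corner coordinates and the confidence separately
    corners, confidence = packed[:4], float(packed[4, 1])
    print(corners.shape, confidence)  # (4, 2) 0.87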
doctr/models/detection/differentiable_binarization/pytorch.py
@@ -39,7 +39,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-81e9b152.pt&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/db_mobilenet_v3_large-21748dd0.pt&src=0",
     },
 }
@@ -273,7 +273,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
doctr/models/detection/fast/base.py
@@ -31,7 +31,7 @@ class FASTPostProcessor(DetectionPostProcessor):
 
     def __init__(
         self,
-        bin_thresh: float = 0.3,
+        bin_thresh: float = 0.1,
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
     ) -> None:
@@ -111,7 +111,7 @@ class FASTPostProcessor(DetectionPostProcessor):
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -138,10 +138,11 @@ class FASTPostProcessor(DetectionPostProcessor):
             # compute relative box to get rid of img shape
             _box[:, 0] /= width
             _box[:, 1] /= height
-            boxes.append(_box)
+            # Add score to box as (0, score)
+            boxes.append(np.vstack([_box, np.array([0.0, score])]))
 
         if not self.assume_straight_pages:
-            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
         else:
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
@@ -153,7 +154,7 @@ class _FAST(BaseModel):
 
     min_size_box: int = 3
     assume_straight_pages: bool = True
-    shrink_ratio = 0.1
+    shrink_ratio = 0.4
 
     def build_target(
         self,
doctr/models/detection/fast/pytorch.py
@@ -26,19 +26,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-1acac421.pt&src=0",
     },
     "fast_small": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-10952cc1.pt&src=0",
     },
     "fast_base": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-688a8b34.pt&src=0",
     },
 }
@@ -119,7 +119,7 @@ class FAST(_FAST, nn.Module):
     def __init__(
         self,
         feat_extractor: IntermediateLayerGetter,
-        bin_thresh: float = 0.3,
+        bin_thresh: float = 0.1,
         box_thresh: float = 0.1,
         dropout_prob: float = 0.1,
         pooling_size: int = 4,  # different from paper performs better on close text-rich images
doctr/models/detection/fast/tensorflow.py
@@ -29,19 +29,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (1024, 1024, 3),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-959daecb.zip&src=0",
     },
     "fast_small": {
         "input_shape": (1024, 1024, 3),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-f1617503.zip&src=0",
     },
     "fast_base": {
         "input_shape": (1024, 1024, 3),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-255e2ac3.zip&src=0",
     },
 }
@@ -122,7 +122,7 @@ class FAST(_FAST, keras.Model, NestedObject):
     def __init__(
         self,
         feature_extractor: IntermediateLayerGetter,
-        bin_thresh: float = 0.3,
+        bin_thresh: float = 0.1,
         box_thresh: float = 0.1,
         dropout_prob: float = 0.1,
         pooling_size: int = 4,  # different from paper performs better on close text-rich images
doctr/models/detection/linknet/base.py
@@ -111,7 +111,7 @@ class LinkNetPostProcessor(DetectionPostProcessor):
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -138,10 +138,11 @@ class LinkNetPostProcessor(DetectionPostProcessor):
             # compute relative box to get rid of img shape
             _box[:, 0] /= width
             _box[:, 1] /= height
-            boxes.append(_box)
+            # Add score to box as (0, score)
+            boxes.append(np.vstack([_box, np.array([0.0, score])]))
 
         if not self.assume_straight_pages:
-            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
         else:
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
 
doctr/models/detection/predictor/pytorch.py
@@ -9,6 +9,7 @@ import numpy as np
 import torch
 from torch import nn
 
+from doctr.models.detection._utils import _remove_padding
 from doctr.models.preprocessor import PreProcessor
 from doctr.models.utils import set_device_and_dtype
 
@@ -40,6 +41,11 @@ class DetectionPredictor(nn.Module):
         return_maps: bool = False,
         **kwargs: Any,
     ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
+        # Extract parameters from the preprocessor
+        preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
+        symmetric_pad = self.pre_processor.resize.symmetric_pad
+        assume_straight_pages = self.model.assume_straight_pages
+
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
@@ -52,7 +58,15 @@ class DetectionPredictor(nn.Module):
         predicted_batches = [
             self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
         ]
-        preds = [pred for batch in predicted_batches for pred in batch["preds"]]
+        # Remove padding from loc predictions
+        preds = _remove_padding(
+            pages,  # type: ignore[arg-type]
+            [pred for batch in predicted_batches for pred in batch["preds"]],
+            preserve_aspect_ratio=preserve_aspect_ratio,
+            symmetric_pad=symmetric_pad,
+            assume_straight_pages=assume_straight_pages,
+        )
+
         if return_maps:
             seg_maps = [
                 pred.permute(1, 2, 0).detach().cpu().numpy() for batch in predicted_batches for pred in batch["out_map"]
doctr/models/detection/predictor/tensorflow.py
@@ -9,6 +9,7 @@ import numpy as np
 import tensorflow as tf
 from tensorflow import keras
 
+from doctr.models.detection._utils import _remove_padding
 from doctr.models.preprocessor import PreProcessor
 from doctr.utils.repr import NestedObject
 
@@ -40,6 +41,11 @@ class DetectionPredictor(NestedObject):
         return_maps: bool = False,
         **kwargs: Any,
     ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
+        # Extract parameters from the preprocessor
+        preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
+        symmetric_pad = self.pre_processor.resize.symmetric_pad
+        assume_straight_pages = self.model.assume_straight_pages
+
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
@@ -50,7 +56,15 @@ class DetectionPredictor(NestedObject):
             for batch in processed_batches
         ]
 
-        preds = [pred for batch in predicted_batches for pred in batch["preds"]]
+        # Remove padding from loc predictions
+        preds = _remove_padding(
+            pages,
+            [pred for batch in predicted_batches for pred in batch["preds"]],
+            preserve_aspect_ratio=preserve_aspect_ratio,
+            symmetric_pad=symmetric_pad,
+            assume_straight_pages=assume_straight_pages,
+        )
+
         if return_maps:
             seg_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
             return preds, seg_maps
doctr/models/detection/zoo.py
@@ -8,6 +8,7 @@ from typing import Any, List
 from doctr.file_utils import is_tf_available, is_torch_available
 
 from .. import detection
+from ..detection.fast import reparameterize
 from ..preprocessor import PreProcessor
 from .predictor import DetectionPredictor
 
@@ -51,18 +52,22 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
             pretrained_backbone=kwargs.get("pretrained_backbone", True),
             assume_straight_pages=assume_straight_pages,
         )
+        # Reparameterize FAST models by default to lower inference latency and memory usage
+        if isinstance(_model, detection.FAST):
+            _model = reparameterize(_model)
     else:
         if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
             raise ValueError(f"unknown architecture: {type(arch)}")
 
         _model = arch
         _model.assume_straight_pages = assume_straight_pages
+        _model.postprocessor.assume_straight_pages = assume_straight_pages
 
     kwargs.pop("pretrained_backbone", None)
 
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size", 1)
+    kwargs["batch_size"] = kwargs.get("batch_size", 2)
     predictor = DetectionPredictor(
         PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
         _model,
@@ -71,7 +76,7 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
 
 
 def detection_predictor(
-    arch: Any = "db_resnet50",
+    arch: Any = "fast_base",
     pretrained: bool = False,
     assume_straight_pages: bool = True,
     **kwargs: Any,
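
A minimal usage sketch reflecting the new defaults (random input for illustration): fast_base replaces db_resnet50 as the default detection architecture, and the zoo reparameterizes FAST models on load, so no extra step is needed at inference time.

    import numpy as np
    from doctr.models import detection_predictor

    # Default architecture is now fast_base, loaded in its reparameterized
    # (inference-friendly) form, with a default batch size of 2
    model = detection_predictor(pretrained=True)
    page = (255 * np.random.rand(1024, 1024, 3)).astype(np.uint8)
    # One dict per page, mapping each output class (e.g. "words") to an
    # (N, 5) array of relative boxes carrying their scores
    out = model([page])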
doctr/models/factory/hub.py
@@ -36,7 +36,6 @@ AVAILABLE_ARCHS = {
     "classification": models.classification.zoo.ARCHS,
     "detection": models.detection.zoo.ARCHS,
     "recognition": models.recognition.zoo.ARCHS,
-    "obj_detection": ["fasterrcnn_mobilenet_v3_large_fpn"] if is_torch_available() else None,
 }
 
 
@@ -110,8 +109,8 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
 
     if run_config is None and arch is None:
         raise ValueError("run_config or arch must be specified")
-    if task not in ["classification", "detection", "recognition", "obj_detection"]:
-        raise ValueError("task must be one of classification, detection, recognition, obj_detection")
+    if task not in ["classification", "detection", "recognition"]:
+        raise ValueError("task must be one of classification, detection, recognition")
 
     # default readme
     readme = textwrap.dedent(
@@ -165,7 +164,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
         \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
     )
 
-    if arch not in AVAILABLE_ARCHS[task]:  # type: ignore
+    if arch not in AVAILABLE_ARCHS[task]:
         raise ValueError(
             f"Architecture: {arch} for task: {task} not found.\
             \nAvailable architectures: {AVAILABLE_ARCHS}"
@@ -217,14 +216,6 @@ def from_hub(repo_id: str, **kwargs: Any):
         model = models.detection.__dict__[arch](pretrained=False)
     elif task == "recognition":
         model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"])
-    elif task == "obj_detection" and is_torch_available():
-        model = models.obj_detection.__dict__[arch](
-            pretrained=False,
-            image_mean=cfg["mean"],
-            image_std=cfg["std"],
-            max_size=cfg["input_shape"][-1],
-            num_classes=len(cfg["classes"]),
-        )
 
     # update model cfg
     model.cfg = cfg
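
With obj_detection removed, a push now only succeeds for the three remaining tasks. A short sketch using the diffed API (the repository name here is hypothetical):

    from doctr.models import recognition
    from doctr.models.factory import push_to_hf_hub

    model = recognition.crnn_vgg16_bn(pretrained=True)
    # task must now be one of: classification, detection, recognition;
    # either run_config or arch must be supplied
    push_to_hf_hub(model, model_name="doctr-crnn-vgg16-bn", task="recognition", arch="crnn_vgg16_bn")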
doctr/models/kie_predictor/base.py
@@ -7,7 +7,7 @@ from typing import Any, Optional
 
 from doctr.models.builder import KIEDocumentBuilder
 
-from ..classification.predictor import CropOrientationPredictor
+from ..classification.predictor import OrientationPredictor
 from ..predictor.base import _OCRPredictor
 
 __all__ = ["_KIEPredictor"]
@@ -25,10 +25,13 @@ class _KIEPredictor(_OCRPredictor):
         accordingly. Doing so will improve performances for documents with page-uniform rotations.
         preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
         symmetric_pad: if True and preserve_aspect_ratio is True, pad the image symmetrically.
+        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
+            page. Doing so will slightly deteriorate the overall latency.
         kwargs: keyword args of `DocumentBuilder`
     """
 
-    crop_orientation_predictor: Optional[CropOrientationPredictor]
+    crop_orientation_predictor: Optional[OrientationPredictor]
+    page_orientation_predictor: Optional[OrientationPredictor]
 
     def __init__(
         self,
@@ -36,8 +39,11 @@ class _KIEPredictor(_OCRPredictor):
         straighten_pages: bool = False,
         preserve_aspect_ratio: bool = True,
         symmetric_pad: bool = True,
+        detect_orientation: bool = False,
         **kwargs: Any,
     ) -> None:
-        super().__init__(assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs)
+        super().__init__(
+            assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs
+        )
 
         self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)
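
Assuming the public kie_predictor zoo function forwards the new flag to the _KIEPredictor subclasses it instantiates (an assumption; only the base class appears in this diff), enabling page orientation detection would look like this sketch:

    import numpy as np
    from doctr.models import kie_predictor

    # detect_orientation wires in the new page_orientation_predictor; per the
    # docstring above, it slightly increases the overall latency
    predictor = kie_predictor(pretrained=True, detect_orientation=True)
    page = (255 * np.random.rand(1024, 768, 3)).astype(np.uint8)
    result = predictor([page])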