python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/cord.py +10 -1
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +11 -1
  8. doctr/datasets/generator/base.py +6 -5
  9. doctr/datasets/ic03.py +11 -1
  10. doctr/datasets/ic13.py +10 -1
  11. doctr/datasets/iiit5k.py +26 -16
  12. doctr/datasets/imgur5k.py +11 -2
  13. doctr/datasets/loader.py +1 -6
  14. doctr/datasets/sroie.py +11 -1
  15. doctr/datasets/svhn.py +11 -1
  16. doctr/datasets/svt.py +11 -1
  17. doctr/datasets/synthtext.py +11 -1
  18. doctr/datasets/utils.py +9 -3
  19. doctr/datasets/vocabs.py +15 -4
  20. doctr/datasets/wildreceipt.py +12 -1
  21. doctr/file_utils.py +45 -12
  22. doctr/io/elements.py +52 -10
  23. doctr/io/html.py +2 -2
  24. doctr/io/image/pytorch.py +6 -8
  25. doctr/io/image/tensorflow.py +1 -1
  26. doctr/io/pdf.py +5 -2
  27. doctr/io/reader.py +6 -0
  28. doctr/models/__init__.py +0 -1
  29. doctr/models/_utils.py +57 -20
  30. doctr/models/builder.py +73 -15
  31. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  32. doctr/models/classification/mobilenet/pytorch.py +47 -9
  33. doctr/models/classification/mobilenet/tensorflow.py +51 -14
  34. doctr/models/classification/predictor/pytorch.py +28 -17
  35. doctr/models/classification/predictor/tensorflow.py +26 -16
  36. doctr/models/classification/resnet/tensorflow.py +21 -8
  37. doctr/models/classification/textnet/pytorch.py +3 -3
  38. doctr/models/classification/textnet/tensorflow.py +11 -5
  39. doctr/models/classification/vgg/tensorflow.py +9 -3
  40. doctr/models/classification/vit/tensorflow.py +10 -4
  41. doctr/models/classification/zoo.py +55 -19
  42. doctr/models/detection/_utils/__init__.py +1 -0
  43. doctr/models/detection/_utils/base.py +66 -0
  44. doctr/models/detection/differentiable_binarization/base.py +4 -3
  45. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  46. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  47. doctr/models/detection/fast/base.py +6 -5
  48. doctr/models/detection/fast/pytorch.py +4 -4
  49. doctr/models/detection/fast/tensorflow.py +15 -12
  50. doctr/models/detection/linknet/base.py +4 -3
  51. doctr/models/detection/linknet/tensorflow.py +23 -11
  52. doctr/models/detection/predictor/pytorch.py +15 -1
  53. doctr/models/detection/predictor/tensorflow.py +17 -3
  54. doctr/models/detection/zoo.py +7 -2
  55. doctr/models/factory/hub.py +8 -18
  56. doctr/models/kie_predictor/base.py +13 -3
  57. doctr/models/kie_predictor/pytorch.py +45 -20
  58. doctr/models/kie_predictor/tensorflow.py +44 -17
  59. doctr/models/modules/layers/pytorch.py +2 -3
  60. doctr/models/modules/layers/tensorflow.py +6 -8
  61. doctr/models/modules/transformer/pytorch.py +2 -2
  62. doctr/models/modules/transformer/tensorflow.py +0 -2
  63. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  64. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  65. doctr/models/predictor/base.py +97 -58
  66. doctr/models/predictor/pytorch.py +35 -20
  67. doctr/models/predictor/tensorflow.py +35 -18
  68. doctr/models/preprocessor/pytorch.py +4 -4
  69. doctr/models/preprocessor/tensorflow.py +3 -2
  70. doctr/models/recognition/crnn/tensorflow.py +8 -6
  71. doctr/models/recognition/master/pytorch.py +2 -2
  72. doctr/models/recognition/master/tensorflow.py +9 -4
  73. doctr/models/recognition/parseq/pytorch.py +4 -3
  74. doctr/models/recognition/parseq/tensorflow.py +14 -11
  75. doctr/models/recognition/sar/pytorch.py +7 -6
  76. doctr/models/recognition/sar/tensorflow.py +10 -12
  77. doctr/models/recognition/vitstr/pytorch.py +1 -1
  78. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  79. doctr/models/recognition/zoo.py +1 -1
  80. doctr/models/utils/pytorch.py +1 -1
  81. doctr/models/utils/tensorflow.py +15 -15
  82. doctr/models/zoo.py +2 -2
  83. doctr/py.typed +0 -0
  84. doctr/transforms/functional/base.py +1 -1
  85. doctr/transforms/functional/pytorch.py +5 -5
  86. doctr/transforms/modules/base.py +37 -15
  87. doctr/transforms/modules/pytorch.py +73 -14
  88. doctr/transforms/modules/tensorflow.py +78 -19
  89. doctr/utils/fonts.py +7 -5
  90. doctr/utils/geometry.py +141 -31
  91. doctr/utils/metrics.py +34 -175
  92. doctr/utils/reconstitution.py +212 -0
  93. doctr/utils/visualization.py +5 -118
  94. doctr/version.py +1 -1
  95. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
  96. python_doctr-0.10.0.dist-info/RECORD +173 -0
  97. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  98. doctr/models/artefacts/__init__.py +0 -2
  99. doctr/models/artefacts/barcode.py +0 -74
  100. doctr/models/artefacts/face.py +0 -63
  101. doctr/models/obj_detection/__init__.py +0 -1
  102. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  103. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  104. python_doctr-0.8.1.dist-info/RECORD +0 -173
  105. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  106. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  107. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
@@ -10,10 +10,10 @@ import torch
10
10
  from torch import nn
11
11
 
12
12
  from doctr.io.elements import Document
13
- from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
13
+ from doctr.models._utils import get_language, invert_data_structure
14
14
  from doctr.models.detection.predictor import DetectionPredictor
15
15
  from doctr.models.recognition.predictor import RecognitionPredictor
16
- from doctr.utils.geometry import rotate_image
16
+ from doctr.utils.geometry import detach_scores
17
17
 
18
18
  from .base import _KIEPredictor
19
19
 
@@ -55,7 +55,13 @@ class KIEPredictor(nn.Module, _KIEPredictor):
55
55
  self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
56
56
  self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
57
57
  _KIEPredictor.__init__(
58
- self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
58
+ self,
59
+ assume_straight_pages,
60
+ straighten_pages,
61
+ preserve_aspect_ratio,
62
+ symmetric_pad,
63
+ detect_orientation,
64
+ **kwargs,
59
65
  )
60
66
  self.detect_orientation = detect_orientation
61
67
  self.detect_language = detect_language
@@ -83,29 +89,34 @@ class KIEPredictor(nn.Module, _KIEPredictor):
83
89
  for out_map in out_maps
84
90
  ]
85
91
  if self.detect_orientation:
86
- origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
92
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) # type: ignore[arg-type]
87
93
  orientations = [
88
- {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
94
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
89
95
  ]
90
96
  else:
91
97
  orientations = None
98
+ general_pages_orientations = None
99
+ origin_pages_orientations = None
92
100
  if self.straighten_pages:
93
- origin_page_orientations = (
94
- origin_page_orientations
95
- if self.detect_orientation
96
- else [estimate_orientation(seq_map) for seq_map in seg_maps]
97
- )
98
- pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
101
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
102
+ # update page shapes after straightening
103
+ origin_page_shapes = [page.shape[:2] for page in pages]
104
+
99
105
  # Forward again to get predictions on straight pages
100
106
  loc_preds = self.det_predictor(pages, **kwargs)
101
107
 
102
108
  dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]
109
+
110
+ # Detach objectness scores from loc_preds
111
+ objectness_scores = {}
112
+ for class_name, det_preds in dict_loc_preds.items():
113
+ _loc_preds, _scores = detach_scores(det_preds)
114
+ dict_loc_preds[class_name] = _loc_preds
115
+ objectness_scores[class_name] = _scores
116
+
103
117
  # Check whether crop mode should be switched to channels first
104
118
  channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
105
119
 
106
- # Rectify crops if aspect ratio
107
- dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
108
-
109
120
  # Apply hooks to loc_preds if any
110
121
  for hook in self.hooks:
111
122
  dict_loc_preds = hook(dict_loc_preds)
@@ -114,32 +125,44 @@ class KIEPredictor(nn.Module, _KIEPredictor):
114
125
  crops = {}
115
126
  for class_name in dict_loc_preds.keys():
116
127
  crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
117
- pages,
128
+ pages, # type: ignore[arg-type]
118
129
  dict_loc_preds[class_name],
119
130
  channels_last=channels_last,
120
131
  assume_straight_pages=self.assume_straight_pages,
132
+ assume_horizontal=self._page_orientation_disabled,
121
133
  )
122
134
  # Rectify crop orientation
135
+ crop_orientations: Any = {}
123
136
  if not self.assume_straight_pages:
124
137
  for class_name in dict_loc_preds.keys():
125
- crops[class_name], dict_loc_preds[class_name] = self._rectify_crops(
138
+ crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
126
139
  crops[class_name], dict_loc_preds[class_name]
127
140
  )
141
+ crop_orientations[class_name] = [
142
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
143
+ ]
144
+
128
145
  # Identify character sequences
129
146
  word_preds = {
130
147
  k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
131
148
  for k, crop_value in crops.items()
132
149
  }
150
+ if not crop_orientations:
151
+ crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
133
152
 
134
153
  boxes: Dict = {}
135
154
  text_preds: Dict = {}
155
+ word_crop_orientations: Dict = {}
136
156
  for class_name in dict_loc_preds.keys():
137
- boxes[class_name], text_preds[class_name] = self._process_predictions(
138
- dict_loc_preds[class_name], word_preds[class_name]
157
+ boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
158
+ dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
139
159
  )
140
160
 
141
161
  boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
162
+ objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
142
163
  text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
164
+ crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
165
+
143
166
  if self.detect_language:
144
167
  languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
145
168
  languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
@@ -147,10 +170,12 @@ class KIEPredictor(nn.Module, _KIEPredictor):
147
170
  languages_dict = None
148
171
 
149
172
  out = self.doc_builder(
150
- pages,
173
+ pages, # type: ignore[arg-type]
151
174
  boxes_per_page,
175
+ objectness_scores_per_page,
152
176
  text_preds_per_page,
153
- origin_page_shapes,
177
+ origin_page_shapes, # type: ignore[arg-type]
178
+ crop_orientations_per_page,
154
179
  orientations,
155
180
  languages_dict,
156
181
  )
@@ -9,10 +9,10 @@ import numpy as np
9
9
  import tensorflow as tf
10
10
 
11
11
  from doctr.io.elements import Document
12
- from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
12
+ from doctr.models._utils import get_language, invert_data_structure
13
13
  from doctr.models.detection.predictor import DetectionPredictor
14
14
  from doctr.models.recognition.predictor import RecognitionPredictor
15
- from doctr.utils.geometry import rotate_image
15
+ from doctr.utils.geometry import detach_scores
16
16
  from doctr.utils.repr import NestedObject
17
17
 
18
18
  from .base import _KIEPredictor
@@ -56,7 +56,13 @@ class KIEPredictor(NestedObject, _KIEPredictor):
56
56
  self.det_predictor = det_predictor
57
57
  self.reco_predictor = reco_predictor
58
58
  _KIEPredictor.__init__(
59
- self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
59
+ self,
60
+ assume_straight_pages,
61
+ straighten_pages,
62
+ preserve_aspect_ratio,
63
+ symmetric_pad,
64
+ detect_orientation,
65
+ **kwargs,
60
66
  )
61
67
  self.detect_orientation = detect_orientation
62
68
  self.detect_language = detect_language
@@ -83,25 +89,30 @@ class KIEPredictor(NestedObject, _KIEPredictor):
83
89
  for out_map in out_maps
84
90
  ]
85
91
  if self.detect_orientation:
86
- origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
92
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
87
93
  orientations = [
88
- {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
94
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
89
95
  ]
90
96
  else:
91
97
  orientations = None
98
+ general_pages_orientations = None
99
+ origin_pages_orientations = None
92
100
  if self.straighten_pages:
93
- origin_page_orientations = (
94
- origin_page_orientations
95
- if self.detect_orientation
96
- else [estimate_orientation(seq_map) for seq_map in seg_maps]
97
- )
98
- pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
101
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
102
+ # update page shapes after straightening
103
+ origin_page_shapes = [page.shape[:2] for page in pages]
104
+
99
105
  # Forward again to get predictions on straight pages
100
106
  loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
101
107
 
102
108
  dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
103
- # Rectify crops if aspect ratio
104
- dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
109
+
110
+ # Detach objectness scores from loc_preds
111
+ objectness_scores = {}
112
+ for class_name, det_preds in dict_loc_preds.items():
113
+ _loc_preds, _scores = detach_scores(det_preds)
114
+ dict_loc_preds[class_name] = _loc_preds
115
+ objectness_scores[class_name] = _scores
105
116
 
106
117
  # Apply hooks to loc_preds if any
107
118
  for hook in self.hooks:
@@ -111,30 +122,44 @@ class KIEPredictor(NestedObject, _KIEPredictor):
111
122
  crops = {}
112
123
  for class_name in dict_loc_preds.keys():
113
124
  crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
114
- pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
125
+ pages,
126
+ dict_loc_preds[class_name],
127
+ channels_last=True,
128
+ assume_straight_pages=self.assume_straight_pages,
129
+ assume_horizontal=self._page_orientation_disabled,
115
130
  )
131
+
116
132
  # Rectify crop orientation
133
+ crop_orientations: Any = {}
117
134
  if not self.assume_straight_pages:
118
135
  for class_name in dict_loc_preds.keys():
119
- crops[class_name], dict_loc_preds[class_name] = self._rectify_crops(
136
+ crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
120
137
  crops[class_name], dict_loc_preds[class_name]
121
138
  )
139
+ crop_orientations[class_name] = [
140
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
141
+ ]
122
142
 
123
143
  # Identify character sequences
124
144
  word_preds = {
125
145
  k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
126
146
  for k, crop_value in crops.items()
127
147
  }
148
+ if not crop_orientations:
149
+ crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
128
150
 
129
151
  boxes: Dict = {}
130
152
  text_preds: Dict = {}
153
+ word_crop_orientations: Dict = {}
131
154
  for class_name in dict_loc_preds.keys():
132
- boxes[class_name], text_preds[class_name] = self._process_predictions(
133
- dict_loc_preds[class_name], word_preds[class_name]
155
+ boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
156
+ dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
134
157
  )
135
158
 
136
159
  boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
160
+ objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
137
161
  text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
162
+ crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
138
163
 
139
164
  if self.detect_language:
140
165
  languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -145,8 +170,10 @@ class KIEPredictor(NestedObject, _KIEPredictor):
145
170
  out = self.doc_builder(
146
171
  pages,
147
172
  boxes_per_page,
173
+ objectness_scores_per_page,
148
174
  text_preds_per_page,
149
175
  origin_page_shapes, # type: ignore[arg-type]
176
+ crop_orientations_per_page,
150
177
  orientations,
151
178
  languages_dict,
152
179
  )
@@ -87,7 +87,7 @@ class FASTConvLayer(nn.Module):
87
87
  horizontal_outputs = (
88
88
  self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0
89
89
  )
90
- id_out = self.rbr_identity(x) if self.rbr_identity is not None and self.ver_bn is not None else 0
90
+ id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0
91
91
 
92
92
  return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
93
93
 
@@ -106,7 +106,7 @@ class FASTConvLayer(nn.Module):
106
106
  id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
107
107
  self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
108
108
  kernel = self.id_tensor
109
- std = (identity.running_var + identity.eps).sqrt() # type: ignore[attr-defined]
109
+ std = (identity.running_var + identity.eps).sqrt()
110
110
  t = (identity.weight / std).reshape(-1, 1, 1, 1)
111
111
  return kernel * t, identity.bias - identity.running_mean * identity.weight / std
112
112
 
@@ -155,7 +155,6 @@ class FASTConvLayer(nn.Module):
155
155
  )
156
156
  self.fused_conv.weight.data = kernel
157
157
  self.fused_conv.bias.data = bias # type: ignore[union-attr]
158
- self.deploy = True
159
158
  for para in self.parameters():
160
159
  para.detach_()
161
160
  for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
@@ -97,7 +97,7 @@ class FASTConvLayer(layers.Layer, NestedObject):
97
97
  if self.hor_bn is not None and self.hor_conv is not None
98
98
  else 0
99
99
  )
100
- id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None and self.ver_bn is not None else 0
100
+ id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
101
101
 
102
102
  return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
103
103
 
@@ -110,14 +110,14 @@ class FASTConvLayer(layers.Layer, NestedObject):
110
110
  return 0, 0
111
111
  if not hasattr(self, "id_tensor"):
112
112
  input_dim = self.in_channels // self.groups
113
- kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
113
+ kernel_value = np.zeros((1, 1, input_dim, self.in_channels), dtype=np.float32)
114
114
  for i in range(self.in_channels):
115
- kernel_value[i, i % input_dim, 0, 0] = 1
115
+ kernel_value[0, 0, i % input_dim, i] = 1
116
116
  id_tensor = tf.constant(kernel_value, dtype=tf.float32)
117
117
  self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
118
118
  kernel = self.id_tensor
119
119
  std = tf.sqrt(identity.moving_variance + identity.epsilon)
120
- t = tf.reshape(identity.gamma / std, (-1, 1, 1, 1))
120
+ t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
121
121
  return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
122
122
 
123
123
  def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
@@ -138,18 +138,16 @@ class FASTConvLayer(layers.Layer, NestedObject):
138
138
  else:
139
139
  kernel_1xn, bias_1xn = 0, 0
140
140
  kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
141
- if not isinstance(kernel_id, int):
142
- kernel_id = tf.transpose(kernel_id, (2, 3, 0, 1))
143
141
  kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
144
142
  bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
145
143
  return kernel_mxn, bias_mxn
146
144
 
147
145
  def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor:
148
146
  kernel_height, kernel_width = self.converted_ks
149
- height, width = kernel.shape[2:]
147
+ height, width = kernel.shape[:2]
150
148
  pad_left_right = tf.maximum(0, (kernel_width - width) // 2)
151
149
  pad_top_down = tf.maximum(0, (kernel_height - height) // 2)
152
- return tf.pad(kernel, [[0, 0], [0, 0], [pad_top_down, pad_top_down], [pad_left_right, pad_left_right]])
150
+ return tf.pad(kernel, [[pad_top_down, pad_top_down], [pad_left_right, pad_left_right], [0, 0], [0, 0]])
153
151
 
154
152
  def reparameterize_layer(self):
155
153
  kernel, bias = self._get_equivalent_kernel_bias()
@@ -51,8 +51,8 @@ def scaled_dot_product_attention(
51
51
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
52
52
  if mask is not None:
53
53
  # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
54
- scores = scores.masked_fill(mask == 0, float("-inf")) # type: ignore[attr-defined]
55
- p_attn = torch.softmax(scores, dim=-1) # type: ignore[call-overload]
54
+ scores = scores.masked_fill(mask == 0, float("-inf"))
55
+ p_attn = torch.softmax(scores, dim=-1)
56
56
  return torch.matmul(p_attn, value), p_attn
57
57
 
58
58
 
@@ -13,8 +13,6 @@ from doctr.utils.repr import NestedObject
13
13
 
14
14
  __all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "PositionwiseFeedForward", "MultiHeadAttention"]
15
15
 
16
- tf.config.run_functions_eagerly(True)
17
-
18
16
 
19
17
  class PositionalEncoding(layers.Layer, NestedObject):
20
18
  """Compute positional encoding"""
@@ -20,7 +20,7 @@ class PatchEmbedding(nn.Module):
20
20
  channels, height, width = input_shape
21
21
  self.patch_size = patch_size
22
22
  self.interpolate = True if patch_size[0] == patch_size[1] else False
23
- self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)])
23
+ self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
24
24
  self.num_patches = self.grid_size[0] * self.grid_size[1]
25
25
 
26
26
  self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
@@ -22,7 +22,7 @@ class PatchEmbedding(layers.Layer, NestedObject):
22
22
  height, width, _ = input_shape
23
23
  self.patch_size = patch_size
24
24
  self.interpolate = True if patch_size[0] == patch_size[1] else False
25
- self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)])
25
+ self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
26
26
  self.num_patches = self.grid_size[0] * self.grid_size[1]
27
27
 
28
28
  self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token")
@@ -3,16 +3,16 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Callable, List, Optional, Tuple
6
+ from typing import Any, Callable, Dict, List, Optional, Tuple
7
7
 
8
8
  import numpy as np
9
9
 
10
10
  from doctr.models.builder import DocumentBuilder
11
- from doctr.utils.geometry import extract_crops, extract_rcrops
11
+ from doctr.utils.geometry import extract_crops, extract_rcrops, remove_image_padding, rotate_image
12
12
 
13
- from .._utils import rectify_crops, rectify_loc_preds
14
- from ..classification import crop_orientation_predictor
15
- from ..classification.predictor import CropOrientationPredictor
13
+ from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
14
+ from ..classification import crop_orientation_predictor, page_orientation_predictor
15
+ from ..classification.predictor import OrientationPredictor
16
16
 
17
17
  __all__ = ["_OCRPredictor"]
18
18
 
@@ -29,10 +29,13 @@ class _OCRPredictor:
29
29
  accordingly. Doing so will improve performances for documents with page-uniform rotations.
30
30
  preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
31
31
  symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically.
32
+ detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
33
+ page. Doing so will slightly deteriorate the overall latency.
32
34
  **kwargs: keyword args of `DocumentBuilder`
33
35
  """
34
36
 
35
- crop_orientation_predictor: Optional[CropOrientationPredictor]
37
+ crop_orientation_predictor: Optional[OrientationPredictor]
38
+ page_orientation_predictor: Optional[OrientationPredictor]
36
39
 
37
40
  def __init__(
38
41
  self,
@@ -40,29 +43,93 @@ class _OCRPredictor:
40
43
  straighten_pages: bool = False,
41
44
  preserve_aspect_ratio: bool = True,
42
45
  symmetric_pad: bool = True,
46
+ detect_orientation: bool = False,
43
47
  **kwargs: Any,
44
48
  ) -> None:
45
49
  self.assume_straight_pages = assume_straight_pages
46
50
  self.straighten_pages = straighten_pages
47
- self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True)
51
+ self._page_orientation_disabled = kwargs.pop("disable_page_orientation", False)
52
+ self._crop_orientation_disabled = kwargs.pop("disable_crop_orientation", False)
53
+ self.crop_orientation_predictor = (
54
+ None
55
+ if assume_straight_pages
56
+ else crop_orientation_predictor(pretrained=True, disabled=self._crop_orientation_disabled)
57
+ )
58
+ self.page_orientation_predictor = (
59
+ page_orientation_predictor(pretrained=True, disabled=self._page_orientation_disabled)
60
+ if detect_orientation or straighten_pages or not assume_straight_pages
61
+ else None
62
+ )
48
63
  self.doc_builder = DocumentBuilder(**kwargs)
49
64
  self.preserve_aspect_ratio = preserve_aspect_ratio
50
65
  self.symmetric_pad = symmetric_pad
51
66
  self.hooks: List[Callable] = []
52
67
 
68
+ def _general_page_orientations(
69
+ self,
70
+ pages: List[np.ndarray],
71
+ ) -> List[Tuple[int, float]]:
72
+ _, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
73
+ # Flatten to list of tuples with (value, confidence)
74
+ page_orientations = [
75
+ (orientation, prob)
76
+ for page_classes, page_probs in zip(classes, probs)
77
+ for orientation, prob in zip(page_classes, page_probs)
78
+ ]
79
+ return page_orientations
80
+
81
+ def _get_orientations(
82
+ self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
83
+ ) -> Tuple[List[Tuple[int, float]], List[int]]:
84
+ general_pages_orientations = self._general_page_orientations(pages)
85
+ origin_page_orientations = [
86
+ estimate_orientation(seq_map, general_orientation)
87
+ for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
88
+ ]
89
+ return general_pages_orientations, origin_page_orientations
90
+
91
+ def _straighten_pages(
92
+ self,
93
+ pages: List[np.ndarray],
94
+ seg_maps: List[np.ndarray],
95
+ general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
96
+ origin_pages_orientations: Optional[List[int]] = None,
97
+ ) -> List[np.ndarray]:
98
+ general_pages_orientations = (
99
+ general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
100
+ )
101
+ origin_pages_orientations = (
102
+ origin_pages_orientations
103
+ if origin_pages_orientations
104
+ else [
105
+ estimate_orientation(seq_map, general_orientation)
106
+ for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
107
+ ]
108
+ )
109
+ return [
110
+ # expand if height and width are not equal, then remove the padding
111
+ remove_image_padding(rotate_image(page, angle, expand=page.shape[0] != page.shape[1]))
112
+ for page, angle in zip(pages, origin_pages_orientations)
113
+ ]
114
+
53
115
  @staticmethod
54
116
  def _generate_crops(
55
117
  pages: List[np.ndarray],
56
118
  loc_preds: List[np.ndarray],
57
119
  channels_last: bool,
58
120
  assume_straight_pages: bool = False,
121
+ assume_horizontal: bool = False,
59
122
  ) -> List[List[np.ndarray]]:
60
- extraction_fn = extract_crops if assume_straight_pages else extract_rcrops
61
-
62
- crops = [
63
- extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator]
64
- for page, _boxes in zip(pages, loc_preds)
65
- ]
123
+ if assume_straight_pages:
124
+ crops = [
125
+ extract_crops(page, _boxes[:, :4], channels_last=channels_last)
126
+ for page, _boxes in zip(pages, loc_preds)
127
+ ]
128
+ else:
129
+ crops = [
130
+ extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
131
+ for page, _boxes in zip(pages, loc_preds)
132
+ ]
66
133
  return crops
67
134
 
68
135
  @staticmethod
@@ -71,8 +138,9 @@ class _OCRPredictor:
71
138
  loc_preds: List[np.ndarray],
72
139
  channels_last: bool,
73
140
  assume_straight_pages: bool = False,
141
+ assume_horizontal: bool = False,
74
142
  ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
75
- crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)
143
+ crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
76
144
 
77
145
  # Avoid sending zero-sized crops
78
146
  is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
@@ -88,68 +156,39 @@ class _OCRPredictor:
88
156
  self,
89
157
  crops: List[List[np.ndarray]],
90
158
  loc_preds: List[np.ndarray],
91
- ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
159
+ ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]:
92
160
  # Work at a page level
93
- orientations = [self.crop_orientation_predictor(page_crops) for page_crops in crops] # type: ignore[misc]
161
+ orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc]
94
162
  rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
95
163
  rect_loc_preds = [
96
164
  rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds
97
165
  for page_loc_preds, orientation in zip(loc_preds, orientations)
98
166
  ]
99
- return rect_crops, rect_loc_preds # type: ignore[return-value]
100
-
101
- def _remove_padding(
102
- self,
103
- pages: List[np.ndarray],
104
- loc_preds: List[np.ndarray],
105
- ) -> List[np.ndarray]:
106
- if self.preserve_aspect_ratio:
107
- # Rectify loc_preds to remove padding
108
- rectified_preds = []
109
- for page, loc_pred in zip(pages, loc_preds):
110
- h, w = page.shape[0], page.shape[1]
111
- if h > w:
112
- # y unchanged, dilate x coord
113
- if self.symmetric_pad:
114
- if self.assume_straight_pages:
115
- loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
116
- else:
117
- loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
118
- else:
119
- if self.assume_straight_pages:
120
- loc_pred[:, [0, 2]] *= h / w
121
- else:
122
- loc_pred[:, :, 0] *= h / w
123
- elif w > h:
124
- # x unchanged, dilate y coord
125
- if self.symmetric_pad:
126
- if self.assume_straight_pages:
127
- loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
128
- else:
129
- loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
130
- else:
131
- if self.assume_straight_pages:
132
- loc_pred[:, [1, 3]] *= w / h
133
- else:
134
- loc_pred[:, :, 1] *= w / h
135
- rectified_preds.append(loc_pred)
136
- return rectified_preds
137
- return loc_preds
167
+ # Flatten to list of tuples with (value, confidence)
168
+ crop_orientations = [
169
+ (orientation, prob)
170
+ for page_classes, page_probs in zip(classes, probs)
171
+ for orientation, prob in zip(page_classes, page_probs)
172
+ ]
173
+ return rect_crops, rect_loc_preds, crop_orientations # type: ignore[return-value]
138
174
 
139
175
  @staticmethod
140
176
  def _process_predictions(
141
177
  loc_preds: List[np.ndarray],
142
178
  word_preds: List[Tuple[str, float]],
143
- ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]]]:
179
+ crop_orientations: List[Dict[str, Any]],
180
+ ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]:
144
181
  text_preds = []
182
+ crop_orientation_preds = []
145
183
  if len(loc_preds) > 0:
146
- # Text
184
+ # Text & crop orientation predictions at page level
147
185
  _idx = 0
148
186
  for page_boxes in loc_preds:
149
187
  text_preds.append(word_preds[_idx : _idx + page_boxes.shape[0]])
188
+ crop_orientation_preds.append(crop_orientations[_idx : _idx + page_boxes.shape[0]])
150
189
  _idx += page_boxes.shape[0]
151
190
 
152
- return loc_preds, text_preds
191
+ return loc_preds, text_preds, crop_orientation_preds
153
192
 
154
193
  def add_hook(self, hook: Callable) -> None:
155
194
  """Add a hook to the predictor