python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +1 -1
- doctr/contrib/__init__.py +0 -0
- doctr/contrib/artefacts.py +131 -0
- doctr/contrib/base.py +105 -0
- doctr/datasets/cord.py +10 -1
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +11 -1
- doctr/datasets/generator/base.py +6 -5
- doctr/datasets/ic03.py +11 -1
- doctr/datasets/ic13.py +10 -1
- doctr/datasets/iiit5k.py +26 -16
- doctr/datasets/imgur5k.py +11 -2
- doctr/datasets/loader.py +1 -6
- doctr/datasets/sroie.py +11 -1
- doctr/datasets/svhn.py +11 -1
- doctr/datasets/svt.py +11 -1
- doctr/datasets/synthtext.py +11 -1
- doctr/datasets/utils.py +9 -3
- doctr/datasets/vocabs.py +15 -4
- doctr/datasets/wildreceipt.py +12 -1
- doctr/file_utils.py +45 -12
- doctr/io/elements.py +52 -10
- doctr/io/html.py +2 -2
- doctr/io/image/pytorch.py +6 -8
- doctr/io/image/tensorflow.py +1 -1
- doctr/io/pdf.py +5 -2
- doctr/io/reader.py +6 -0
- doctr/models/__init__.py +0 -1
- doctr/models/_utils.py +57 -20
- doctr/models/builder.py +73 -15
- doctr/models/classification/magc_resnet/tensorflow.py +13 -6
- doctr/models/classification/mobilenet/pytorch.py +47 -9
- doctr/models/classification/mobilenet/tensorflow.py +51 -14
- doctr/models/classification/predictor/pytorch.py +28 -17
- doctr/models/classification/predictor/tensorflow.py +26 -16
- doctr/models/classification/resnet/tensorflow.py +21 -8
- doctr/models/classification/textnet/pytorch.py +3 -3
- doctr/models/classification/textnet/tensorflow.py +11 -5
- doctr/models/classification/vgg/tensorflow.py +9 -3
- doctr/models/classification/vit/tensorflow.py +10 -4
- doctr/models/classification/zoo.py +55 -19
- doctr/models/detection/_utils/__init__.py +1 -0
- doctr/models/detection/_utils/base.py +66 -0
- doctr/models/detection/differentiable_binarization/base.py +4 -3
- doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
- doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
- doctr/models/detection/fast/base.py +6 -5
- doctr/models/detection/fast/pytorch.py +4 -4
- doctr/models/detection/fast/tensorflow.py +15 -12
- doctr/models/detection/linknet/base.py +4 -3
- doctr/models/detection/linknet/tensorflow.py +23 -11
- doctr/models/detection/predictor/pytorch.py +15 -1
- doctr/models/detection/predictor/tensorflow.py +17 -3
- doctr/models/detection/zoo.py +7 -2
- doctr/models/factory/hub.py +8 -18
- doctr/models/kie_predictor/base.py +13 -3
- doctr/models/kie_predictor/pytorch.py +45 -20
- doctr/models/kie_predictor/tensorflow.py +44 -17
- doctr/models/modules/layers/pytorch.py +2 -3
- doctr/models/modules/layers/tensorflow.py +6 -8
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/transformer/tensorflow.py +0 -2
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/modules/vision_transformer/tensorflow.py +1 -1
- doctr/models/predictor/base.py +97 -58
- doctr/models/predictor/pytorch.py +35 -20
- doctr/models/predictor/tensorflow.py +35 -18
- doctr/models/preprocessor/pytorch.py +4 -4
- doctr/models/preprocessor/tensorflow.py +3 -2
- doctr/models/recognition/crnn/tensorflow.py +8 -6
- doctr/models/recognition/master/pytorch.py +2 -2
- doctr/models/recognition/master/tensorflow.py +9 -4
- doctr/models/recognition/parseq/pytorch.py +4 -3
- doctr/models/recognition/parseq/tensorflow.py +14 -11
- doctr/models/recognition/sar/pytorch.py +7 -6
- doctr/models/recognition/sar/tensorflow.py +10 -12
- doctr/models/recognition/vitstr/pytorch.py +1 -1
- doctr/models/recognition/vitstr/tensorflow.py +9 -4
- doctr/models/recognition/zoo.py +1 -1
- doctr/models/utils/pytorch.py +1 -1
- doctr/models/utils/tensorflow.py +15 -15
- doctr/models/zoo.py +2 -2
- doctr/py.typed +0 -0
- doctr/transforms/functional/base.py +1 -1
- doctr/transforms/functional/pytorch.py +5 -5
- doctr/transforms/modules/base.py +37 -15
- doctr/transforms/modules/pytorch.py +73 -14
- doctr/transforms/modules/tensorflow.py +78 -19
- doctr/utils/fonts.py +7 -5
- doctr/utils/geometry.py +141 -31
- doctr/utils/metrics.py +34 -175
- doctr/utils/reconstitution.py +212 -0
- doctr/utils/visualization.py +5 -118
- doctr/version.py +1 -1
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
- python_doctr-0.10.0.dist-info/RECORD +173 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
- doctr/models/artefacts/__init__.py +0 -2
- doctr/models/artefacts/barcode.py +0 -74
- doctr/models/artefacts/face.py +0 -63
- doctr/models/obj_detection/__init__.py +0 -1
- doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
- python_doctr-0.8.1.dist-info/RECORD +0 -173
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/_utils.py
CHANGED
@@ -11,6 +11,8 @@ import cv2
 import numpy as np
 from langdetect import LangDetectException, detect_langs
 
+from doctr.utils.geometry import rotate_image
+
 __all__ = ["estimate_orientation", "get_language", "invert_data_structure"]
 
 
@@ -29,56 +31,91 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
     return max(w / h, h / w)
 
 
-def estimate_orientation(
+def estimate_orientation(
+    img: np.ndarray,
+    general_page_orientation: Optional[Tuple[int, float]] = None,
+    n_ct: int = 70,
+    ratio_threshold_for_lines: float = 3,
+    min_confidence: float = 0.2,
+    lower_area: int = 100,
+) -> int:
     """Estimate the angle of the general document orientation based on the
     lines of the document and the assumption that they should be horizontal.
 
     Args:
     ----
         img: the img or bitmap to analyze (H, W, C)
+        general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence)
+            estimated by a model
         n_ct: the number of contours used for the orientation estimation
         ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines
+        min_confidence: the minimum confidence to consider the general_page_orientation
+        lower_area: the minimum area of a contour to be considered
 
     Returns:
     -------
-        the angle of the
+        the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation)
     """
     assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
-
-
-    if
-        thresh = img.astype(np.uint8)
-    if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
+    thresh = None
+    # Convert image to grayscale if necessary
+    if img.shape[-1] == 3:
         gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
         gray_img = cv2.medianBlur(gray_img, 5)
-        thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
-
-    # try to merge words in lines
-    (h, w) = img.shape[:2]
-    k_x = max(1, (floor(w / 100)))
-    k_y = max(1, (floor(h / 100)))
-    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
-    thresh = cv2.dilate(thresh, kernel, iterations=1)
+        thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
+    else:
+        thresh = img.astype(np.uint8)  # type: ignore[assignment]
+
+    page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
+    if page_orientation and orientation_confidence >= min_confidence:
+        # We rotate the image to the general orientation which improves the detection
+        # No expand needed bitmap is already padded
+        thresh = rotate_image(thresh, -page_orientation)  # type: ignore
+    else:  # That's only required if we do not work on the detection models bin map
+        # try to merge words in lines
+        (h, w) = img.shape[:2]
+        k_x = max(1, (floor(w / 100)))
+        k_y = max(1, (floor(h / 100)))
+        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
+        thresh = cv2.dilate(thresh, kernel, iterations=1)
 
     # extract contours
     contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
 
-    # Sort contours
-    contours = sorted(
+    # Filter & Sort contours
+    contours = sorted(
+        [contour for contour in contours if cv2.contourArea(contour) > lower_area],
+        key=get_max_width_length_ratio,
+        reverse=True,
+    )
 
     angles = []
     for contour in contours[:n_ct]:
-        _, (w, h), angle = cv2.minAreaRect(contour)
+        _, (w, h), angle = cv2.minAreaRect(contour)  # type: ignore[assignment]
         if w / h > ratio_threshold_for_lines:  # select only contours with ratio like lines
             angles.append(angle)
         elif w / h < 1 / ratio_threshold_for_lines:  # if lines are vertical, substract 90 degree
             angles.append(angle - 90)
 
     if len(angles) == 0:
-
+        estimated_angle = 0  # in case no angles is found
     else:
         median = -median_low(angles)
-
+        estimated_angle = -round(median) if abs(median) != 0 else 0
+
+    # combine with the general orientation and the estimated angle
+    if page_orientation and orientation_confidence >= min_confidence:
+        # special case where the estimated angle is mostly wrong:
+        # case 1: - and + swapped
+        # case 2: estimated angle is completely wrong
+        # so in this case we prefer the general page orientation
+        if abs(estimated_angle) == abs(page_orientation):
+            return page_orientation
+        estimated_angle = estimated_angle if page_orientation == 0 else page_orientation + estimated_angle
+        if estimated_angle > 180:
+            estimated_angle -= 360
+
+    return estimated_angle  # return the clockwise angle (negative - left side rotation, positive - right side rotation)
 
 
 def rectify_crops(
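
Taken together, the hunk above turns `estimate_orientation` from a purely contour-based heuristic into one that can fold in a page-orientation prior predicted by a classification model. A minimal usage sketch, based only on the signature shown above; the synthetic image and the `(angle, confidence)` prior values are illustrative assumptions:

    import numpy as np

    from doctr.models._utils import estimate_orientation

    page = np.zeros((1024, 768, 3), dtype=np.uint8)  # placeholder (H, W, C) image

    # Contour-based estimate only, as before
    angle = estimate_orientation(page)

    # With a model prior: once its confidence reaches min_confidence, the prior
    # drives the estimate and the combined clockwise angle is returned
    angle = estimate_orientation(page, general_page_orientation=(90, 0.8), min_confidence=0.2)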
doctr/models/builder.py
CHANGED
@@ -31,7 +31,7 @@ class DocumentBuilder(NestedObject):
     def __init__(
         self,
         resolve_lines: bool = True,
-        resolve_blocks: bool = True,
+        resolve_blocks: bool = False,
         paragraph_break: float = 0.035,
         export_as_straight_boxes: bool = False,
     ) -> None:
@@ -220,13 +220,22 @@ class DocumentBuilder(NestedObject):
 
         return blocks
 
-    def _build_blocks(
+    def _build_blocks(
+        self,
+        boxes: np.ndarray,
+        objectness_scores: np.ndarray,
+        word_preds: List[Tuple[str, float]],
+        crop_orientations: List[Dict[str, Any]],
+    ) -> List[Block]:
         """Gather independent words in structured blocks
 
         Args:
         ----
-            boxes: bounding boxes of all detected words of the page, of shape (N,
+            boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
+            objectness_scores: objectness scores of all detected words of the page, of shape N
             word_preds: list of all detected words of the page, of shape N
+            crop_orientations: list of dictoinaries containing
+                the general orientation (orientations + confidences) of the crops
 
         Returns:
         -------
@@ -257,10 +266,17 @@ class DocumentBuilder(NestedObject):
                 Line([
                     Word(
                         *word_preds[idx],
-                        tuple(
+                        tuple(tuple(pt) for pt in boxes[idx].tolist()),  # type: ignore[arg-type]
+                        float(objectness_scores[idx]),
+                        crop_orientations[idx],
                     )
                     if boxes.ndim == 3
-                    else Word(
+                    else Word(
+                        *word_preds[idx],
+                        ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
+                        float(objectness_scores[idx]),
+                        crop_orientations[idx],
+                    )
                     for idx in line
                 ])
                 for line in lines
@@ -281,8 +297,10 @@ class DocumentBuilder(NestedObject):
         self,
         pages: List[np.ndarray],
         boxes: List[np.ndarray],
+        objectness_scores: List[np.ndarray],
         text_preds: List[List[Tuple[str, float]]],
         page_shapes: List[Tuple[int, int]],
+        crop_orientations: List[Dict[str, Any]],
         orientations: Optional[List[Dict[str, Any]]] = None,
         languages: Optional[List[Dict[str, Any]]] = None,
     ) -> Document:
@@ -291,10 +309,13 @@ class DocumentBuilder(NestedObject):
         Args:
         ----
             pages: list of N elements, where each element represents the page image
-            boxes: list of N elements, where each element represents the localization predictions, of shape (*,
-                or (*,
+            boxes: list of N elements, where each element represents the localization predictions, of shape (*, 4)
+                or (*, 4, 2) for all words for a given page
+            objectness_scores: list of N elements, where each element represents the objectness scores
             text_preds: list of N elements, where each element is the list of all word prediction (text + confidence)
             page_shapes: shape of each page, of size N
+            crop_orientations: list of N elements, where each element is
+                a dictionary containing the general orientation (orientations + confidences) of the crops
             orientations: optional, list of N elements,
                 where each element is a dictionary containing the orientation (orientation + confidence)
             languages: optional, list of N elements,
@@ -304,7 +325,9 @@ class DocumentBuilder(NestedObject):
         -------
             document object
         """
-        if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
+        if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
+            page_shapes
+        ) != len(crop_orientations) != len(objectness_scores):
             raise ValueError("All arguments are expected to be lists of the same size")
 
         _orientations = (
@@ -322,15 +345,25 @@
                 page,
                 self._build_blocks(
                     page_boxes,
+                    loc_scores,
                     word_preds,
+                    word_crop_orientations,
                 ),
                 _idx,
                 shape,
                 orientation,
                 language,
             )
-            for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
-                pages,
+            for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language in zip(  # noqa: E501
+                pages,
+                range(len(boxes)),
+                page_shapes,
+                boxes,
+                objectness_scores,
+                text_preds,
+                crop_orientations,
+                _orientations,
+                _languages,
             )
         ]
 
@@ -353,8 +386,10 @@ class KIEDocumentBuilder(DocumentBuilder):
         self,
         pages: List[np.ndarray],
         boxes: List[Dict[str, np.ndarray]],
+        objectness_scores: List[Dict[str, np.ndarray]],
         text_preds: List[Dict[str, List[Tuple[str, float]]]],
         page_shapes: List[Tuple[int, int]],
+        crop_orientations: List[Dict[str, List[Dict[str, Any]]]],
         orientations: Optional[List[Dict[str, Any]]] = None,
         languages: Optional[List[Dict[str, Any]]] = None,
     ) -> KIEDocument:
@@ -365,8 +400,11 @@
             pages: list of N elements, where each element represents the page image
             boxes: list of N dictionaries, where each element represents the localization predictions for a class,
                 of shape (*, 5) or (*, 6) for all predictions
+            objectness_scores: list of N dictionaries, where each element represents the objectness scores for a class
             text_preds: list of N dictionaries, where each element is the list of all word prediction
             page_shapes: shape of each page, of size N
+            crop_orientations: list of N dictonaries, where each element is
+                a list containing the general crop orientations (orientations + confidences) of the crops
             orientations: optional, list of N elements,
                 where each element is a dictionary containing the orientation (orientation + confidence)
             languages: optional, list of N elements,
@@ -376,7 +414,9 @@
         -------
             document object
         """
-        if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
+        if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
+            page_shapes
+        ) != len(crop_orientations) != len(objectness_scores):
             raise ValueError("All arguments are expected to be lists of the same size")
         _orientations = (
             orientations if isinstance(orientations, list) else [None] * len(boxes)  # type: ignore[list-item]
@@ -401,7 +441,9 @@
                 {
                     k: self._build_blocks(
                         page_boxes[k],
+                        loc_scores[k],
                         word_preds[k],
+                        word_crop_orientations[k],
                     )
                     for k in page_boxes.keys()
                 },
@@ -410,8 +452,16 @@
                 orientation,
                 language,
             )
-            for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
-                pages,
+            for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language in zip(  # noqa: E501
+                pages,
+                range(len(boxes)),
+                page_shapes,
+                boxes,
+                objectness_scores,
+                text_preds,
+                crop_orientations,
+                _orientations,
+                _languages,
             )
         ]
 
@@ -420,14 +470,18 @@
     def _build_blocks(  # type: ignore[override]
         self,
         boxes: np.ndarray,
+        objectness_scores: np.ndarray,
         word_preds: List[Tuple[str, float]],
+        crop_orientations: List[Dict[str, Any]],
     ) -> List[Prediction]:
         """Gather independent words in structured blocks
 
         Args:
         ----
-            boxes: bounding boxes of all detected words of the page, of shape (N,
+            boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
+            objectness_scores: objectness scores of all detected words of the page
             word_preds: list of all detected words of the page, of shape N
+            crop_orientations: list of orientations for each word crop
 
         Returns:
         -------
@@ -446,13 +500,17 @@
             Prediction(
                 value=word_preds[idx][0],
                 confidence=word_preds[idx][1],
-                geometry=tuple(
+                geometry=tuple(tuple(pt) for pt in boxes[idx].tolist()),  # type: ignore[arg-type]
+                objectness_score=float(objectness_scores[idx]),
+                crop_orientation=crop_orientations[idx],
             )
             if boxes.ndim == 3
             else Prediction(
                 value=word_preds[idx][0],
                 confidence=word_preds[idx][1],
                 geometry=((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
+                objectness_score=float(objectness_scores[idx]),
+                crop_orientation=crop_orientations[idx],
            )
             for idx in idxs
         ]
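
`DocumentBuilder.__call__` and `KIEDocumentBuilder.__call__` now thread two extra per-page lists, `objectness_scores` and `crop_orientations`, through `_build_blocks` into every `Word`/`Prediction`, and `resolve_blocks` defaults to `False`. A hedged sketch of the new calling convention; the argument order comes from the diff, while the relative box coordinates and the `{"value", "confidence"}` layout of the orientation dicts are assumptions:

    import numpy as np

    from doctr.models.builder import DocumentBuilder

    builder = DocumentBuilder()  # resolve_blocks is now False by default

    pages = [np.zeros((1024, 768, 3), dtype=np.uint8)]
    boxes = [np.array([[0.1, 0.1, 0.4, 0.2]])]                 # (*, 4) relative coords (assumed)
    objectness_scores = [np.array([0.99])]                     # new: one score per box
    text_preds = [[("hello", 0.95)]]
    page_shapes = [(1024, 768)]
    crop_orientations = [[{"value": 0, "confidence": None}]]   # new: one dict per box (layout assumed)

    doc = builder(pages, boxes, objectness_scores, text_preds, page_shapes, crop_orientations)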
doctr/models/classification/magc_resnet/tensorflow.py
CHANGED
@@ -9,12 +9,12 @@ from functools import partial
 from typing import Any, Dict, List, Optional, Tuple
 
 import tensorflow as tf
-from tensorflow.keras import layers
+from tensorflow.keras import activations, layers
 from tensorflow.keras.models import Sequential
 
 from doctr.datasets import VOCABS
 
-from ...utils import load_pretrained_params
+from ...utils import _build_model, load_pretrained_params
 from ..resnet.tensorflow import ResNet
 
 __all__ = ["magc_resnet31"]
@@ -26,7 +26,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
     },
 }
 
@@ -57,6 +57,7 @@ class MAGC(layers.Layer):
         self.headers = headers  # h
         self.inplanes = inplanes  # C
         self.attn_scale = attn_scale
+        self.ratio = ratio
         self.planes = int(inplanes * ratio)
 
         self.single_header_inplanes = int(inplanes / headers)  # C / h
@@ -97,7 +98,7 @@ class MAGC(layers.Layer):
         if self.attn_scale and self.headers > 1:
             context_mask = context_mask / math.sqrt(self.single_header_inplanes)
         # B*h, 1, H*W, 1
-        context_mask =
+        context_mask = activations.softmax(context_mask, axis=2)
 
         # Compute context
         # B*h, 1, C/h, 1
@@ -114,7 +115,7 @@ class MAGC(layers.Layer):
         # Context modeling: B, H, W, C -> B, 1, 1, C
         context = self.context_modeling(inputs)
         # Transform: B, 1, 1, C -> B, 1, 1, C
-        transformed = self.transform(context)
+        transformed = self.transform(context, **kwargs)
         return inputs + transformed
 
 
@@ -151,9 +152,15 @@ def _magc_resnet(
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
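
Two patterns recur across the TensorFlow classification files in this release: models are built eagerly right after instantiation (`_build_model`), and pretrained weights are loaded with `skip_mismatch` whenever the requested `num_classes` differs from the checkpoint's head, so the classifier can be re-initialized for fine-tuning. A sketch of what that enables, assuming `num_classes` is forwarded through `**kwargs` as in the other doctr classification factories:

    from doctr.models import magc_resnet31

    # Same head as the checkpoint: all weights are loaded
    model = magc_resnet31(pretrained=True)

    # Custom head: the mismatching classifier layer is skipped on load,
    # leaving a freshly initialized head to fine-tune
    model = magc_resnet31(pretrained=True, num_classes=10)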
doctr/models/classification/mobilenet/pytorch.py
CHANGED
@@ -9,17 +9,20 @@ from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
 from torchvision.models import mobilenetv3
+from torchvision.models.mobilenetv3 import MobileNetV3
 
 from doctr.datasets import VOCABS
 
 from ...utils import load_pretrained_params
 
 __all__ = [
+    "MobileNetV3",
     "mobilenet_v3_small",
     "mobilenet_v3_small_r",
     "mobilenet_v3_large",
     "mobilenet_v3_large_r",
-    "mobilenet_v3_small_orientation",
+    "mobilenet_v3_small_crop_orientation",
+    "mobilenet_v3_small_page_orientation",
 ]
 
 default_cfgs: Dict[str, Dict[str, Any]] = {
@@ -51,12 +54,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "classes": list(VOCABS["french"]),
         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-1a8a3530.pt&src=0",
     },
-    "mobilenet_v3_small_orientation": {
+    "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
-        "input_shape": (3,
-        "classes": [0, 90, 180,
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "input_shape": (3, 256, 256),
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_crop_orientation-f0847a18.pt&src=0",
+    },
+    "mobilenet_v3_small_page_orientation": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 512, 512),
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-8e60325c.pt&src=0",
     },
 }
 
@@ -212,14 +222,42 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
     )
 
 
-def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+    """MobileNetV3-Small architecture as described in
+    `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+    >>> import torch
+    >>> from doctr.models import mobilenet_v3_small_crop_orientation
+    >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the MobileNetV3 architecture
+
+    Returns:
+    -------
+        a torch.nn.Module
+    """
+    return _mobilenet_v3(
+        "mobilenet_v3_small_crop_orientation",
+        pretrained,
+        ignore_keys=["classifier.3.weight", "classifier.3.bias"],
+        **kwargs,
+    )
+
+
+def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
     """MobileNetV3-Small architecture as described in
     `"Searching for MobileNetV3",
     <https://arxiv.org/pdf/1905.02244.pdf>`_.
 
     >>> import torch
-    >>> from doctr.models import mobilenet_v3_small_orientation
-    >>> model = mobilenet_v3_small_orientation(pretrained=False)
+    >>> from doctr.models import mobilenet_v3_small_page_orientation
+    >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
     >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
     >>> out = model(input_tensor)
 
@@ -233,7 +271,7 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> m
         a torch.nn.Module
     """
     return _mobilenet_v3(
-        "mobilenet_v3_small_orientation",
+        "mobilenet_v3_small_page_orientation",
         pretrained,
         ignore_keys=["classifier.3.weight", "classifier.3.bias"],
         **kwargs,
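
The single `mobilenet_v3_small_orientation` model from 0.8.1 is split into a crop-level and a page-level classifier, both predicting the classes `[0, -90, 180, 90]`. A migration sketch for the PyTorch backend, using only names from the diff:

    # 0.8.1 (removed)
    # from doctr.models import mobilenet_v3_small_orientation

    # 0.10.0: one model per granularity
    from doctr.models import (
        mobilenet_v3_small_crop_orientation,  # trained on (3, 256, 256) crops
        mobilenet_v3_small_page_orientation,  # trained on (3, 512, 512) pages
    )

    crop_clf = mobilenet_v3_small_crop_orientation(pretrained=False)
    page_clf = mobilenet_v3_small_page_orientation(pretrained=False)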
doctr/models/classification/mobilenet/tensorflow.py
CHANGED
@@ -13,7 +13,7 @@ from tensorflow.keras import layers
 from tensorflow.keras.models import Sequential
 
 from ....datasets import VOCABS
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = [
     "MobileNetV3",
@@ -21,7 +21,8 @@ __all__ = [
     "mobilenet_v3_small_r",
     "mobilenet_v3_large",
     "mobilenet_v3_large_r",
-    "mobilenet_v3_small_orientation",
+    "mobilenet_v3_small_crop_orientation",
+    "mobilenet_v3_small_page_orientation",
 ]
 
 
@@ -31,35 +32,42 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
     },
     "mobilenet_v3_large_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
     },
     "mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
     },
     "mobilenet_v3_small_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
     },
-    "mobilenet_v3_small_orientation": {
+    "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (128, 128, 3),
-        "classes": [0, 90, 180,
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
+    },
+    "mobilenet_v3_small_page_orientation": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (512, 512, 3),
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
     },
 }
 
@@ -287,9 +295,15 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
@@ -386,14 +400,37 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
     return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)
 
 
-def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+    """MobileNetV3-Small architecture as described in
+    `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import mobilenet_v3_small_crop_orientation
+    >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the MobileNetV3 architecture
+
+    Returns:
+    -------
+        a keras.Model
+    """
+    return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)
+
+
+def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     """MobileNetV3-Small architecture as described in
     `"Searching for MobileNetV3",
     <https://arxiv.org/pdf/1905.02244.pdf>`_.
 
     >>> import tensorflow as tf
-    >>> from doctr.models import mobilenet_v3_small_orientation
-    >>> model = mobilenet_v3_small_orientation(pretrained=False)
+    >>> from doctr.models import mobilenet_v3_small_page_orientation
+    >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
     >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
     >>> out = model(input_tensor)
 
@@ -406,4 +443,4 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> M
     -------
         a keras.Model
     """
-    return _mobilenet_v3("mobilenet_v3_small_orientation", pretrained, include_top=True, **kwargs)
+    return _mobilenet_v3("mobilenet_v3_small_page_orientation", pretrained, include_top=True, **kwargs)
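
The TensorFlow backend mirrors the same split with NHWC input shapes ((128, 128, 3) for crops, (512, 512, 3) for pages). A short sketch based on the doctest above; the expected output shape is an inference from the four orientation classes and `include_top=True`, not something stated in the diff:

    import tensorflow as tf

    from doctr.models import mobilenet_v3_small_page_orientation

    model = mobilenet_v3_small_page_orientation(pretrained=False)
    x = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
    out = model(x)  # expected: shape (1, 4), one score per class in [0, -90, 180, 90]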