python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/io/reader.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
from collections.abc import Sequence
|
|
6
7
|
from pathlib import Path
|
|
7
|
-
from typing import List, Sequence, Union
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
|
|
@@ -22,37 +22,33 @@ class DocumentFile:
|
|
|
22
22
|
"""Read a document from multiple extensions"""
|
|
23
23
|
|
|
24
24
|
@classmethod
|
|
25
|
-
def from_pdf(cls, file: AbstractFile, **kwargs) ->
|
|
25
|
+
def from_pdf(cls, file: AbstractFile, **kwargs) -> list[np.ndarray]:
|
|
26
26
|
"""Read a PDF file
|
|
27
27
|
|
|
28
28
|
>>> from doctr.io import DocumentFile
|
|
29
29
|
>>> doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
|
|
30
30
|
|
|
31
31
|
Args:
|
|
32
|
-
----
|
|
33
32
|
file: the path to the PDF file or a binary stream
|
|
34
33
|
**kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`
|
|
35
34
|
|
|
36
35
|
Returns:
|
|
37
|
-
-------
|
|
38
36
|
the list of pages decoded as numpy ndarray of shape H x W x 3
|
|
39
37
|
"""
|
|
40
38
|
return read_pdf(file, **kwargs)
|
|
41
39
|
|
|
42
40
|
@classmethod
|
|
43
|
-
def from_url(cls, url: str, **kwargs) ->
|
|
41
|
+
def from_url(cls, url: str, **kwargs) -> list[np.ndarray]:
|
|
44
42
|
"""Interpret a web page as a PDF document
|
|
45
43
|
|
|
46
44
|
>>> from doctr.io import DocumentFile
|
|
47
45
|
>>> doc = DocumentFile.from_url("https://www.yoursite.com")
|
|
48
46
|
|
|
49
47
|
Args:
|
|
50
|
-
----
|
|
51
48
|
url: the URL of the target web page
|
|
52
49
|
**kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`
|
|
53
50
|
|
|
54
51
|
Returns:
|
|
55
|
-
-------
|
|
56
52
|
the list of pages decoded as numpy ndarray of shape H x W x 3
|
|
57
53
|
"""
|
|
58
54
|
requires_package(
|
|
@@ -64,19 +60,17 @@ class DocumentFile:
|
|
|
64
60
|
return cls.from_pdf(pdf_stream, **kwargs)
|
|
65
61
|
|
|
66
62
|
@classmethod
|
|
67
|
-
def from_images(cls, files:
|
|
63
|
+
def from_images(cls, files: Sequence[AbstractFile] | AbstractFile, **kwargs) -> list[np.ndarray]:
|
|
68
64
|
"""Read an image file (or a collection of image files) and convert it into an image in numpy format
|
|
69
65
|
|
|
70
66
|
>>> from doctr.io import DocumentFile
|
|
71
67
|
>>> pages = DocumentFile.from_images(["path/to/your/page1.png", "path/to/your/page2.png"])
|
|
72
68
|
|
|
73
69
|
Args:
|
|
74
|
-
----
|
|
75
70
|
files: the path to the image file or a binary stream, or a collection of those
|
|
76
71
|
**kwargs: additional parameters to :meth:`doctr.io.image.read_img_as_numpy`
|
|
77
72
|
|
|
78
73
|
Returns:
|
|
79
|
-
-------
|
|
80
74
|
the list of pages decoded as numpy ndarray of shape H x W x 3
|
|
81
75
|
"""
|
|
82
76
|
if isinstance(files, (str, Path, bytes)):
|
doctr/models/_utils.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
from math import floor
|
|
7
7
|
from statistics import median_low
|
|
8
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
9
9
|
|
|
10
10
|
import cv2
|
|
11
11
|
import numpy as np
|
|
@@ -20,11 +20,9 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
|
|
|
20
20
|
"""Get the maximum shape ratio of a contour.
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
contour: the contour from cv2.findContour
|
|
25
24
|
|
|
26
25
|
Returns:
|
|
27
|
-
-------
|
|
28
26
|
the maximum shape ratio
|
|
29
27
|
"""
|
|
30
28
|
_, (w, h), _ = cv2.minAreaRect(contour)
|
|
@@ -33,7 +31,7 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
|
|
|
33
31
|
|
|
34
32
|
def estimate_orientation(
|
|
35
33
|
img: np.ndarray,
|
|
36
|
-
general_page_orientation:
|
|
34
|
+
general_page_orientation: tuple[int, float] | None = None,
|
|
37
35
|
n_ct: int = 70,
|
|
38
36
|
ratio_threshold_for_lines: float = 3,
|
|
39
37
|
min_confidence: float = 0.2,
|
|
@@ -43,7 +41,6 @@ def estimate_orientation(
|
|
|
43
41
|
lines of the document and the assumption that they should be horizontal.
|
|
44
42
|
|
|
45
43
|
Args:
|
|
46
|
-
----
|
|
47
44
|
img: the img or bitmap to analyze (H, W, C)
|
|
48
45
|
general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence)
|
|
49
46
|
estimated by a model
|
|
@@ -53,7 +50,6 @@ def estimate_orientation(
|
|
|
53
50
|
lower_area: the minimum area of a contour to be considered
|
|
54
51
|
|
|
55
52
|
Returns:
|
|
56
|
-
-------
|
|
57
53
|
the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation)
|
|
58
54
|
"""
|
|
59
55
|
assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
|
|
@@ -64,13 +60,13 @@ def estimate_orientation(
|
|
|
64
60
|
gray_img = cv2.medianBlur(gray_img, 5)
|
|
65
61
|
thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
|
|
66
62
|
else:
|
|
67
|
-
thresh = img.astype(np.uint8)
|
|
63
|
+
thresh = img.astype(np.uint8)
|
|
68
64
|
|
|
69
65
|
page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
|
|
70
66
|
if page_orientation and orientation_confidence >= min_confidence:
|
|
71
67
|
# We rotate the image to the general orientation which improves the detection
|
|
72
68
|
# No expand needed bitmap is already padded
|
|
73
|
-
thresh = rotate_image(thresh, -page_orientation)
|
|
69
|
+
thresh = rotate_image(thresh, -page_orientation)
|
|
74
70
|
else: # That's only required if we do not work on the detection models bin map
|
|
75
71
|
# try to merge words in lines
|
|
76
72
|
(h, w) = img.shape[:2]
|
|
@@ -119,9 +115,9 @@ def estimate_orientation(
|
|
|
119
115
|
|
|
120
116
|
|
|
121
117
|
def rectify_crops(
|
|
122
|
-
crops:
|
|
123
|
-
orientations:
|
|
124
|
-
) ->
|
|
118
|
+
crops: list[np.ndarray],
|
|
119
|
+
orientations: list[int],
|
|
120
|
+
) -> list[np.ndarray]:
|
|
125
121
|
"""Rotate each crop of the list according to the predicted orientation:
|
|
126
122
|
0: already straight, no rotation
|
|
127
123
|
1: 90 ccw, rotate 3 times ccw
|
|
@@ -139,8 +135,8 @@ def rectify_crops(
|
|
|
139
135
|
|
|
140
136
|
def rectify_loc_preds(
|
|
141
137
|
page_loc_preds: np.ndarray,
|
|
142
|
-
orientations:
|
|
143
|
-
) ->
|
|
138
|
+
orientations: list[int],
|
|
139
|
+
) -> np.ndarray | None:
|
|
144
140
|
"""Orient the quadrangle (Polygon4P) according to the predicted orientation,
|
|
145
141
|
so that the points are in this order: top L, top R, bot R, bot L if the crop is readable
|
|
146
142
|
"""
|
|
@@ -157,16 +153,14 @@ def rectify_loc_preds(
|
|
|
157
153
|
)
|
|
158
154
|
|
|
159
155
|
|
|
160
|
-
def get_language(text: str) ->
|
|
156
|
+
def get_language(text: str) -> tuple[str, float]:
|
|
161
157
|
"""Get languages of a text using langdetect model.
|
|
162
158
|
Get the language with the highest probability or no language if only a few words or a low probability
|
|
163
159
|
|
|
164
160
|
Args:
|
|
165
|
-
----
|
|
166
161
|
text (str): text
|
|
167
162
|
|
|
168
163
|
Returns:
|
|
169
|
-
-------
|
|
170
164
|
The detected language in ISO 639 code and confidence score
|
|
171
165
|
"""
|
|
172
166
|
try:
|
|
@@ -179,16 +173,14 @@ def get_language(text: str) -> Tuple[str, float]:
|
|
|
179
173
|
|
|
180
174
|
|
|
181
175
|
def invert_data_structure(
|
|
182
|
-
x:
|
|
183
|
-
) ->
|
|
184
|
-
"""Invert a
|
|
176
|
+
x: list[dict[str, Any]] | dict[str, list[Any]],
|
|
177
|
+
) -> list[dict[str, Any]] | dict[str, list[Any]]:
|
|
178
|
+
"""Invert a list of dict of elements to a dict of list of elements and the other way around
|
|
185
179
|
|
|
186
180
|
Args:
|
|
187
|
-
----
|
|
188
181
|
x: a list of dictionaries with the same keys or a dictionary of lists of the same length
|
|
189
182
|
|
|
190
183
|
Returns:
|
|
191
|
-
-------
|
|
192
184
|
dictionary of list when x is a list of dictionaries or a list of dictionaries when x is dictionary of lists
|
|
193
185
|
"""
|
|
194
186
|
if isinstance(x, dict):
|
doctr/models/builder.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from scipy.cluster.hierarchy import fclusterdata
|
|
@@ -20,7 +20,6 @@ class DocumentBuilder(NestedObject):
|
|
|
20
20
|
"""Implements a document builder
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
resolve_lines: whether words should be automatically grouped into lines
|
|
25
24
|
resolve_blocks: whether lines should be automatically grouped into blocks
|
|
26
25
|
paragraph_break: relative length of the minimum space separating paragraphs
|
|
@@ -41,15 +40,13 @@ class DocumentBuilder(NestedObject):
|
|
|
41
40
|
self.export_as_straight_boxes = export_as_straight_boxes
|
|
42
41
|
|
|
43
42
|
@staticmethod
|
|
44
|
-
def _sort_boxes(boxes: np.ndarray) ->
|
|
43
|
+
def _sort_boxes(boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
|
45
44
|
"""Sort bounding boxes from top to bottom, left to right
|
|
46
45
|
|
|
47
46
|
Args:
|
|
48
|
-
----
|
|
49
47
|
boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox)
|
|
50
48
|
|
|
51
49
|
Returns:
|
|
52
|
-
-------
|
|
53
50
|
tuple: indices of ordered boxes of shape (N,), boxes
|
|
54
51
|
If straight boxes are passed tpo the function, boxes are unchanged
|
|
55
52
|
else: boxes returned are straight boxes fitted to the straightened rotated boxes
|
|
@@ -65,16 +62,14 @@ class DocumentBuilder(NestedObject):
|
|
|
65
62
|
boxes = np.concatenate((boxes.min(1), boxes.max(1)), -1)
|
|
66
63
|
return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort(), boxes
|
|
67
64
|
|
|
68
|
-
def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs:
|
|
65
|
+
def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: list[int]) -> list[list[int]]:
|
|
69
66
|
"""Split a line in sub_lines
|
|
70
67
|
|
|
71
68
|
Args:
|
|
72
|
-
----
|
|
73
69
|
boxes: bounding boxes of shape (N, 4)
|
|
74
70
|
word_idcs: list of indexes for the words of the line
|
|
75
71
|
|
|
76
72
|
Returns:
|
|
77
|
-
-------
|
|
78
73
|
A list of (sub-)lines computed from the original line (words)
|
|
79
74
|
"""
|
|
80
75
|
lines = []
|
|
@@ -105,15 +100,13 @@ class DocumentBuilder(NestedObject):
|
|
|
105
100
|
|
|
106
101
|
return lines
|
|
107
102
|
|
|
108
|
-
def _resolve_lines(self, boxes: np.ndarray) ->
|
|
103
|
+
def _resolve_lines(self, boxes: np.ndarray) -> list[list[int]]:
|
|
109
104
|
"""Order boxes to group them in lines
|
|
110
105
|
|
|
111
106
|
Args:
|
|
112
|
-
----
|
|
113
107
|
boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox
|
|
114
108
|
|
|
115
109
|
Returns:
|
|
116
|
-
-------
|
|
117
110
|
nested list of box indices
|
|
118
111
|
"""
|
|
119
112
|
# Sort boxes, and straighten the boxes if they are rotated
|
|
@@ -153,16 +146,14 @@ class DocumentBuilder(NestedObject):
|
|
|
153
146
|
return lines
|
|
154
147
|
|
|
155
148
|
@staticmethod
|
|
156
|
-
def _resolve_blocks(boxes: np.ndarray, lines:
|
|
149
|
+
def _resolve_blocks(boxes: np.ndarray, lines: list[list[int]]) -> list[list[list[int]]]:
|
|
157
150
|
"""Order lines to group them in blocks
|
|
158
151
|
|
|
159
152
|
Args:
|
|
160
|
-
----
|
|
161
153
|
boxes: bounding boxes of shape (N, 4) or (N, 4, 2)
|
|
162
154
|
lines: list of lines, each line is a list of idx
|
|
163
155
|
|
|
164
156
|
Returns:
|
|
165
|
-
-------
|
|
166
157
|
nested list of box indices
|
|
167
158
|
"""
|
|
168
159
|
# Resolve enclosing boxes of lines
|
|
@@ -207,7 +198,7 @@ class DocumentBuilder(NestedObject):
|
|
|
207
198
|
# Compute clusters
|
|
208
199
|
clusters = fclusterdata(box_features, t=0.1, depth=4, criterion="distance", metric="euclidean")
|
|
209
200
|
|
|
210
|
-
_blocks:
|
|
201
|
+
_blocks: dict[int, list[int]] = {}
|
|
211
202
|
# Form clusters
|
|
212
203
|
for line_idx, cluster_idx in enumerate(clusters):
|
|
213
204
|
if cluster_idx in _blocks.keys():
|
|
@@ -224,13 +215,12 @@ class DocumentBuilder(NestedObject):
|
|
|
224
215
|
self,
|
|
225
216
|
boxes: np.ndarray,
|
|
226
217
|
objectness_scores: np.ndarray,
|
|
227
|
-
word_preds:
|
|
228
|
-
crop_orientations:
|
|
229
|
-
) ->
|
|
218
|
+
word_preds: list[tuple[str, float]],
|
|
219
|
+
crop_orientations: list[dict[str, Any]],
|
|
220
|
+
) -> list[Block]:
|
|
230
221
|
"""Gather independent words in structured blocks
|
|
231
222
|
|
|
232
223
|
Args:
|
|
233
|
-
----
|
|
234
224
|
boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
|
|
235
225
|
objectness_scores: objectness scores of all detected words of the page, of shape N
|
|
236
226
|
word_preds: list of all detected words of the page, of shape N
|
|
@@ -238,7 +228,6 @@ class DocumentBuilder(NestedObject):
|
|
|
238
228
|
the general orientation (orientations + confidences) of the crops
|
|
239
229
|
|
|
240
230
|
Returns:
|
|
241
|
-
-------
|
|
242
231
|
list of block elements
|
|
243
232
|
"""
|
|
244
233
|
if boxes.shape[0] != len(word_preds):
|
|
@@ -266,7 +255,7 @@ class DocumentBuilder(NestedObject):
|
|
|
266
255
|
Line([
|
|
267
256
|
Word(
|
|
268
257
|
*word_preds[idx],
|
|
269
|
-
tuple(
|
|
258
|
+
tuple(tuple(pt) for pt in boxes[idx].tolist()), # type: ignore[arg-type]
|
|
270
259
|
float(objectness_scores[idx]),
|
|
271
260
|
crop_orientations[idx],
|
|
272
261
|
)
|
|
@@ -295,19 +284,18 @@ class DocumentBuilder(NestedObject):
|
|
|
295
284
|
|
|
296
285
|
def __call__(
|
|
297
286
|
self,
|
|
298
|
-
pages:
|
|
299
|
-
boxes:
|
|
300
|
-
objectness_scores:
|
|
301
|
-
text_preds:
|
|
302
|
-
page_shapes:
|
|
303
|
-
crop_orientations:
|
|
304
|
-
orientations:
|
|
305
|
-
languages:
|
|
287
|
+
pages: list[np.ndarray],
|
|
288
|
+
boxes: list[np.ndarray],
|
|
289
|
+
objectness_scores: list[np.ndarray],
|
|
290
|
+
text_preds: list[list[tuple[str, float]]],
|
|
291
|
+
page_shapes: list[tuple[int, int]],
|
|
292
|
+
crop_orientations: list[dict[str, Any]],
|
|
293
|
+
orientations: list[dict[str, Any]] | None = None,
|
|
294
|
+
languages: list[dict[str, Any]] | None = None,
|
|
306
295
|
) -> Document:
|
|
307
296
|
"""Re-arrange detected words into structured blocks
|
|
308
297
|
|
|
309
298
|
Args:
|
|
310
|
-
----
|
|
311
299
|
pages: list of N elements, where each element represents the page image
|
|
312
300
|
boxes: list of N elements, where each element represents the localization predictions, of shape (*, 4)
|
|
313
301
|
or (*, 4, 2) for all words for a given page
|
|
@@ -322,7 +310,6 @@ class DocumentBuilder(NestedObject):
|
|
|
322
310
|
where each element is a dictionary containing the language (language + confidence)
|
|
323
311
|
|
|
324
312
|
Returns:
|
|
325
|
-
-------
|
|
326
313
|
document object
|
|
327
314
|
"""
|
|
328
315
|
if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
|
|
@@ -374,7 +361,6 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
374
361
|
"""Implements a KIE document builder
|
|
375
362
|
|
|
376
363
|
Args:
|
|
377
|
-
----
|
|
378
364
|
resolve_lines: whether words should be automatically grouped into lines
|
|
379
365
|
resolve_blocks: whether lines should be automatically grouped into blocks
|
|
380
366
|
paragraph_break: relative length of the minimum space separating paragraphs
|
|
@@ -384,19 +370,18 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
384
370
|
|
|
385
371
|
def __call__( # type: ignore[override]
|
|
386
372
|
self,
|
|
387
|
-
pages:
|
|
388
|
-
boxes:
|
|
389
|
-
objectness_scores:
|
|
390
|
-
text_preds:
|
|
391
|
-
page_shapes:
|
|
392
|
-
crop_orientations:
|
|
393
|
-
orientations:
|
|
394
|
-
languages:
|
|
373
|
+
pages: list[np.ndarray],
|
|
374
|
+
boxes: list[dict[str, np.ndarray]],
|
|
375
|
+
objectness_scores: list[dict[str, np.ndarray]],
|
|
376
|
+
text_preds: list[dict[str, list[tuple[str, float]]]],
|
|
377
|
+
page_shapes: list[tuple[int, int]],
|
|
378
|
+
crop_orientations: list[dict[str, list[dict[str, Any]]]],
|
|
379
|
+
orientations: list[dict[str, Any]] | None = None,
|
|
380
|
+
languages: list[dict[str, Any]] | None = None,
|
|
395
381
|
) -> KIEDocument:
|
|
396
382
|
"""Re-arrange detected words into structured predictions
|
|
397
383
|
|
|
398
384
|
Args:
|
|
399
|
-
----
|
|
400
385
|
pages: list of N elements, where each element represents the page image
|
|
401
386
|
boxes: list of N dictionaries, where each element represents the localization predictions for a class,
|
|
402
387
|
of shape (*, 5) or (*, 6) for all predictions
|
|
@@ -411,7 +396,6 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
411
396
|
where each element is a dictionary containing the language (language + confidence)
|
|
412
397
|
|
|
413
398
|
Returns:
|
|
414
|
-
-------
|
|
415
399
|
document object
|
|
416
400
|
"""
|
|
417
401
|
if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
|
|
@@ -425,7 +409,7 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
425
409
|
if self.export_as_straight_boxes and len(boxes) > 0:
|
|
426
410
|
# If boxes are already straight OK, else fit a bounding rect
|
|
427
411
|
if next(iter(boxes[0].values())).ndim == 3:
|
|
428
|
-
straight_boxes:
|
|
412
|
+
straight_boxes: list[dict[str, np.ndarray]] = []
|
|
429
413
|
# Iterate over pages
|
|
430
414
|
for p_boxes in boxes:
|
|
431
415
|
# Iterate over boxes of the pages
|
|
@@ -471,20 +455,18 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
471
455
|
self,
|
|
472
456
|
boxes: np.ndarray,
|
|
473
457
|
objectness_scores: np.ndarray,
|
|
474
|
-
word_preds:
|
|
475
|
-
crop_orientations:
|
|
476
|
-
) ->
|
|
458
|
+
word_preds: list[tuple[str, float]],
|
|
459
|
+
crop_orientations: list[dict[str, Any]],
|
|
460
|
+
) -> list[Prediction]:
|
|
477
461
|
"""Gather independent words in structured blocks
|
|
478
462
|
|
|
479
463
|
Args:
|
|
480
|
-
----
|
|
481
464
|
boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
|
|
482
465
|
objectness_scores: objectness scores of all detected words of the page
|
|
483
466
|
word_preds: list of all detected words of the page, of shape N
|
|
484
467
|
crop_orientations: list of orientations for each word crop
|
|
485
468
|
|
|
486
469
|
Returns:
|
|
487
|
-
-------
|
|
488
470
|
list of block elements
|
|
489
471
|
"""
|
|
490
472
|
if boxes.shape[0] != len(word_preds):
|
|
@@ -500,7 +482,7 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
500
482
|
Prediction(
|
|
501
483
|
value=word_preds[idx][0],
|
|
502
484
|
confidence=word_preds[idx][1],
|
|
503
|
-
geometry=tuple(
|
|
485
|
+
geometry=tuple(tuple(pt) for pt in boxes[idx].tolist()), # type: ignore[arg-type]
|
|
504
486
|
objectness_score=float(objectness_scores[idx]),
|
|
505
487
|
crop_orientation=crop_orientations[idx],
|
|
506
488
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
4
6
|
from .tensorflow import *
|
|
5
|
-
elif is_torch_available():
|
|
6
|
-
from .pytorch import * # type: ignore[assignment]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|
import math
|
|
8
8
|
from copy import deepcopy
|
|
9
9
|
from functools import partial
|
|
10
|
-
from typing import Any
|
|
10
|
+
from typing import Any
|
|
11
11
|
|
|
12
12
|
import torch
|
|
13
13
|
from torch import nn
|
|
@@ -20,7 +20,7 @@ from ..resnet.pytorch import ResNet
|
|
|
20
20
|
__all__ = ["magc_resnet31"]
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
default_cfgs:
|
|
23
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
24
24
|
"magc_resnet31": {
|
|
25
25
|
"mean": (0.694, 0.695, 0.693),
|
|
26
26
|
"std": (0.299, 0.296, 0.301),
|
|
@@ -36,7 +36,6 @@ class MAGC(nn.Module):
|
|
|
36
36
|
<https://arxiv.org/pdf/1910.02562.pdf>`_.
|
|
37
37
|
|
|
38
38
|
Args:
|
|
39
|
-
----
|
|
40
39
|
inplanes: input channels
|
|
41
40
|
headers: number of headers to split channels
|
|
42
41
|
attn_scale: if True, re-scale attention to counteract the variance distibutions
|
|
@@ -50,7 +49,7 @@ class MAGC(nn.Module):
|
|
|
50
49
|
headers: int = 8,
|
|
51
50
|
attn_scale: bool = False,
|
|
52
51
|
ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper
|
|
53
|
-
cfg:
|
|
52
|
+
cfg: dict[str, Any] | None = None,
|
|
54
53
|
) -> None:
|
|
55
54
|
super().__init__()
|
|
56
55
|
|
|
@@ -105,12 +104,12 @@ class MAGC(nn.Module):
|
|
|
105
104
|
def _magc_resnet(
|
|
106
105
|
arch: str,
|
|
107
106
|
pretrained: bool,
|
|
108
|
-
num_blocks:
|
|
109
|
-
output_channels:
|
|
110
|
-
stage_stride:
|
|
111
|
-
stage_conv:
|
|
112
|
-
stage_pooling:
|
|
113
|
-
ignore_keys:
|
|
107
|
+
num_blocks: list[int],
|
|
108
|
+
output_channels: list[int],
|
|
109
|
+
stage_stride: list[int],
|
|
110
|
+
stage_conv: list[bool],
|
|
111
|
+
stage_pooling: list[tuple[int, int] | None],
|
|
112
|
+
ignore_keys: list[str] | None = None,
|
|
114
113
|
**kwargs: Any,
|
|
115
114
|
) -> ResNet:
|
|
116
115
|
kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
|
|
@@ -154,12 +153,10 @@ def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
154
153
|
>>> out = model(input_tensor)
|
|
155
154
|
|
|
156
155
|
Args:
|
|
157
|
-
----
|
|
158
156
|
pretrained: boolean, True if model is pretrained
|
|
159
157
|
**kwargs: keyword arguments of the ResNet architecture
|
|
160
158
|
|
|
161
159
|
Returns:
|
|
162
|
-
-------
|
|
163
160
|
A feature extractor model
|
|
164
161
|
"""
|
|
165
162
|
return _magc_resnet(
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -6,27 +6,27 @@
|
|
|
6
6
|
import math
|
|
7
7
|
from copy import deepcopy
|
|
8
8
|
from functools import partial
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
|
|
11
11
|
import tensorflow as tf
|
|
12
|
-
from tensorflow.keras import layers
|
|
12
|
+
from tensorflow.keras import activations, layers
|
|
13
13
|
from tensorflow.keras.models import Sequential
|
|
14
14
|
|
|
15
15
|
from doctr.datasets import VOCABS
|
|
16
16
|
|
|
17
|
-
from ...utils import load_pretrained_params
|
|
17
|
+
from ...utils import _build_model, load_pretrained_params
|
|
18
18
|
from ..resnet.tensorflow import ResNet
|
|
19
19
|
|
|
20
20
|
__all__ = ["magc_resnet31"]
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
default_cfgs:
|
|
23
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
24
24
|
"magc_resnet31": {
|
|
25
25
|
"mean": (0.694, 0.695, 0.693),
|
|
26
26
|
"std": (0.299, 0.296, 0.301),
|
|
27
27
|
"input_shape": (32, 32, 3),
|
|
28
28
|
"classes": list(VOCABS["french"]),
|
|
29
|
-
"url": "https://doctr-static.mindee.com/models?id=v0.
|
|
29
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
|
|
30
30
|
},
|
|
31
31
|
}
|
|
32
32
|
|
|
@@ -36,7 +36,6 @@ class MAGC(layers.Layer):
|
|
|
36
36
|
<https://arxiv.org/pdf/1910.02562.pdf>`_.
|
|
37
37
|
|
|
38
38
|
Args:
|
|
39
|
-
----
|
|
40
39
|
inplanes: input channels
|
|
41
40
|
headers: number of headers to split channels
|
|
42
41
|
attn_scale: if True, re-scale attention to counteract the variance distibutions
|
|
@@ -57,6 +56,7 @@ class MAGC(layers.Layer):
|
|
|
57
56
|
self.headers = headers # h
|
|
58
57
|
self.inplanes = inplanes # C
|
|
59
58
|
self.attn_scale = attn_scale
|
|
59
|
+
self.ratio = ratio
|
|
60
60
|
self.planes = int(inplanes * ratio)
|
|
61
61
|
|
|
62
62
|
self.single_header_inplanes = int(inplanes / headers) # C / h
|
|
@@ -97,7 +97,7 @@ class MAGC(layers.Layer):
|
|
|
97
97
|
if self.attn_scale and self.headers > 1:
|
|
98
98
|
context_mask = context_mask / math.sqrt(self.single_header_inplanes)
|
|
99
99
|
# B*h, 1, H*W, 1
|
|
100
|
-
context_mask =
|
|
100
|
+
context_mask = activations.softmax(context_mask, axis=2)
|
|
101
101
|
|
|
102
102
|
# Compute context
|
|
103
103
|
# B*h, 1, C/h, 1
|
|
@@ -114,18 +114,18 @@ class MAGC(layers.Layer):
|
|
|
114
114
|
# Context modeling: B, H, W, C -> B, 1, 1, C
|
|
115
115
|
context = self.context_modeling(inputs)
|
|
116
116
|
# Transform: B, 1, 1, C -> B, 1, 1, C
|
|
117
|
-
transformed = self.transform(context)
|
|
117
|
+
transformed = self.transform(context, **kwargs)
|
|
118
118
|
return inputs + transformed
|
|
119
119
|
|
|
120
120
|
|
|
121
121
|
def _magc_resnet(
|
|
122
122
|
arch: str,
|
|
123
123
|
pretrained: bool,
|
|
124
|
-
num_blocks:
|
|
125
|
-
output_channels:
|
|
126
|
-
stage_downsample:
|
|
127
|
-
stage_conv:
|
|
128
|
-
stage_pooling:
|
|
124
|
+
num_blocks: list[int],
|
|
125
|
+
output_channels: list[int],
|
|
126
|
+
stage_downsample: list[bool],
|
|
127
|
+
stage_conv: list[bool],
|
|
128
|
+
stage_pooling: list[tuple[int, int] | None],
|
|
129
129
|
origin_stem: bool = True,
|
|
130
130
|
**kwargs: Any,
|
|
131
131
|
) -> ResNet:
|
|
@@ -151,9 +151,15 @@ def _magc_resnet(
|
|
|
151
151
|
cfg=_cfg,
|
|
152
152
|
**kwargs,
|
|
153
153
|
)
|
|
154
|
+
_build_model(model)
|
|
155
|
+
|
|
154
156
|
# Load pretrained parameters
|
|
155
157
|
if pretrained:
|
|
156
|
-
|
|
158
|
+
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
159
|
+
# skip the mismatching layers for fine tuning
|
|
160
|
+
load_pretrained_params(
|
|
161
|
+
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
162
|
+
)
|
|
157
163
|
|
|
158
164
|
return model
|
|
159
165
|
|
|
@@ -170,12 +176,10 @@ def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
170
176
|
>>> out = model(input_tensor)
|
|
171
177
|
|
|
172
178
|
Args:
|
|
173
|
-
----
|
|
174
179
|
pretrained: boolean, True if model is pretrained
|
|
175
180
|
**kwargs: keyword arguments of the ResNet architecture
|
|
176
181
|
|
|
177
182
|
Returns:
|
|
178
|
-
-------
|
|
179
183
|
A feature extractor model
|
|
180
184
|
"""
|
|
181
185
|
return _magc_resnet(
|