python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/predictor/base.py
CHANGED
|
@@ -1,14 +1,15 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
|
|
10
11
|
from doctr.models.builder import DocumentBuilder
|
|
11
|
-
from doctr.utils.geometry import extract_crops, extract_rcrops, rotate_image
|
|
12
|
+
from doctr.utils.geometry import extract_crops, extract_rcrops, remove_image_padding, rotate_image
|
|
12
13
|
|
|
13
14
|
from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
|
|
14
15
|
from ..classification import crop_orientation_predictor, page_orientation_predictor
|
|
@@ -21,7 +22,6 @@ class _OCRPredictor:
|
|
|
21
22
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
22
23
|
|
|
23
24
|
Args:
|
|
24
|
-
----
|
|
25
25
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
26
26
|
without rotated textual elements.
|
|
27
27
|
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
|
|
@@ -34,8 +34,8 @@ class _OCRPredictor:
|
|
|
34
34
|
**kwargs: keyword args of `DocumentBuilder`
|
|
35
35
|
"""
|
|
36
36
|
|
|
37
|
-
crop_orientation_predictor:
|
|
38
|
-
page_orientation_predictor:
|
|
37
|
+
crop_orientation_predictor: OrientationPredictor | None
|
|
38
|
+
page_orientation_predictor: OrientationPredictor | None
|
|
39
39
|
|
|
40
40
|
def __init__(
|
|
41
41
|
self,
|
|
@@ -48,21 +48,27 @@ class _OCRPredictor:
|
|
|
48
48
|
) -> None:
|
|
49
49
|
self.assume_straight_pages = assume_straight_pages
|
|
50
50
|
self.straighten_pages = straighten_pages
|
|
51
|
-
self.
|
|
51
|
+
self._page_orientation_disabled = kwargs.pop("disable_page_orientation", False)
|
|
52
|
+
self._crop_orientation_disabled = kwargs.pop("disable_crop_orientation", False)
|
|
53
|
+
self.crop_orientation_predictor = (
|
|
54
|
+
None
|
|
55
|
+
if assume_straight_pages
|
|
56
|
+
else crop_orientation_predictor(pretrained=True, disabled=self._crop_orientation_disabled)
|
|
57
|
+
)
|
|
52
58
|
self.page_orientation_predictor = (
|
|
53
|
-
page_orientation_predictor(pretrained=True)
|
|
59
|
+
page_orientation_predictor(pretrained=True, disabled=self._page_orientation_disabled)
|
|
54
60
|
if detect_orientation or straighten_pages or not assume_straight_pages
|
|
55
61
|
else None
|
|
56
62
|
)
|
|
57
63
|
self.doc_builder = DocumentBuilder(**kwargs)
|
|
58
64
|
self.preserve_aspect_ratio = preserve_aspect_ratio
|
|
59
65
|
self.symmetric_pad = symmetric_pad
|
|
60
|
-
self.hooks:
|
|
66
|
+
self.hooks: list[Callable] = []
|
|
61
67
|
|
|
62
68
|
def _general_page_orientations(
|
|
63
69
|
self,
|
|
64
|
-
pages:
|
|
65
|
-
) ->
|
|
70
|
+
pages: list[np.ndarray],
|
|
71
|
+
) -> list[tuple[int, float]]:
|
|
66
72
|
_, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
|
|
67
73
|
# Flatten to list of tuples with (value, confidence)
|
|
68
74
|
page_orientations = [
|
|
@@ -73,8 +79,8 @@ class _OCRPredictor:
|
|
|
73
79
|
return page_orientations
|
|
74
80
|
|
|
75
81
|
def _get_orientations(
|
|
76
|
-
self, pages:
|
|
77
|
-
) ->
|
|
82
|
+
self, pages: list[np.ndarray], seg_maps: list[np.ndarray]
|
|
83
|
+
) -> tuple[list[tuple[int, float]], list[int]]:
|
|
78
84
|
general_pages_orientations = self._general_page_orientations(pages)
|
|
79
85
|
origin_page_orientations = [
|
|
80
86
|
estimate_orientation(seq_map, general_orientation)
|
|
@@ -84,11 +90,11 @@ class _OCRPredictor:
|
|
|
84
90
|
|
|
85
91
|
def _straighten_pages(
|
|
86
92
|
self,
|
|
87
|
-
pages:
|
|
88
|
-
seg_maps:
|
|
89
|
-
general_pages_orientations:
|
|
90
|
-
origin_pages_orientations:
|
|
91
|
-
) ->
|
|
93
|
+
pages: list[np.ndarray],
|
|
94
|
+
seg_maps: list[np.ndarray],
|
|
95
|
+
general_pages_orientations: list[tuple[int, float]] | None = None,
|
|
96
|
+
origin_pages_orientations: list[int] | None = None,
|
|
97
|
+
) -> list[np.ndarray]:
|
|
92
98
|
general_pages_orientations = (
|
|
93
99
|
general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
|
|
94
100
|
)
|
|
@@ -101,34 +107,40 @@ class _OCRPredictor:
|
|
|
101
107
|
]
|
|
102
108
|
)
|
|
103
109
|
return [
|
|
104
|
-
#
|
|
105
|
-
rotate_image(page, angle, expand=page.shape[
|
|
110
|
+
# expand if height and width are not equal, then remove the padding
|
|
111
|
+
remove_image_padding(rotate_image(page, angle, expand=page.shape[0] != page.shape[1]))
|
|
106
112
|
for page, angle in zip(pages, origin_pages_orientations)
|
|
107
113
|
]
|
|
108
114
|
|
|
109
115
|
@staticmethod
|
|
110
116
|
def _generate_crops(
|
|
111
|
-
pages:
|
|
112
|
-
loc_preds:
|
|
117
|
+
pages: list[np.ndarray],
|
|
118
|
+
loc_preds: list[np.ndarray],
|
|
113
119
|
channels_last: bool,
|
|
114
120
|
assume_straight_pages: bool = False,
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
121
|
+
assume_horizontal: bool = False,
|
|
122
|
+
) -> list[list[np.ndarray]]:
|
|
123
|
+
if assume_straight_pages:
|
|
124
|
+
crops = [
|
|
125
|
+
extract_crops(page, _boxes[:, :4], channels_last=channels_last)
|
|
126
|
+
for page, _boxes in zip(pages, loc_preds)
|
|
127
|
+
]
|
|
128
|
+
else:
|
|
129
|
+
crops = [
|
|
130
|
+
extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
|
|
131
|
+
for page, _boxes in zip(pages, loc_preds)
|
|
132
|
+
]
|
|
122
133
|
return crops
|
|
123
134
|
|
|
124
135
|
@staticmethod
|
|
125
136
|
def _prepare_crops(
|
|
126
|
-
pages:
|
|
127
|
-
loc_preds:
|
|
137
|
+
pages: list[np.ndarray],
|
|
138
|
+
loc_preds: list[np.ndarray],
|
|
128
139
|
channels_last: bool,
|
|
129
140
|
assume_straight_pages: bool = False,
|
|
130
|
-
|
|
131
|
-
|
|
141
|
+
assume_horizontal: bool = False,
|
|
142
|
+
) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
|
|
143
|
+
crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
|
|
132
144
|
|
|
133
145
|
# Avoid sending zero-sized crops
|
|
134
146
|
is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
|
|
@@ -142,9 +154,9 @@ class _OCRPredictor:
|
|
|
142
154
|
|
|
143
155
|
def _rectify_crops(
|
|
144
156
|
self,
|
|
145
|
-
crops:
|
|
146
|
-
loc_preds:
|
|
147
|
-
) ->
|
|
157
|
+
crops: list[list[np.ndarray]],
|
|
158
|
+
loc_preds: list[np.ndarray],
|
|
159
|
+
) -> tuple[list[list[np.ndarray]], list[np.ndarray], list[tuple[int, float]]]:
|
|
148
160
|
# Work at a page level
|
|
149
161
|
orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc]
|
|
150
162
|
rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
|
|
@@ -162,10 +174,10 @@ class _OCRPredictor:
|
|
|
162
174
|
|
|
163
175
|
@staticmethod
|
|
164
176
|
def _process_predictions(
|
|
165
|
-
loc_preds:
|
|
166
|
-
word_preds:
|
|
167
|
-
crop_orientations:
|
|
168
|
-
) ->
|
|
177
|
+
loc_preds: list[np.ndarray],
|
|
178
|
+
word_preds: list[tuple[str, float]],
|
|
179
|
+
crop_orientations: list[dict[str, Any]],
|
|
180
|
+
) -> tuple[list[np.ndarray], list[list[tuple[str, float]]], list[list[dict[str, Any]]]]:
|
|
169
181
|
text_preds = []
|
|
170
182
|
crop_orientation_preds = []
|
|
171
183
|
if len(loc_preds) > 0:
|
|
@@ -182,7 +194,6 @@ class _OCRPredictor:
|
|
|
182
194
|
"""Add a hook to the predictor
|
|
183
195
|
|
|
184
196
|
Args:
|
|
185
|
-
----
|
|
186
197
|
hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
|
|
187
198
|
"""
|
|
188
199
|
self.hooks.append(hook)
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -24,7 +24,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
24
24
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
25
25
|
|
|
26
26
|
Args:
|
|
27
|
-
----
|
|
28
27
|
det_predictor: detection module
|
|
29
28
|
reco_predictor: recognition module
|
|
30
29
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
@@ -52,8 +51,8 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
52
51
|
**kwargs: Any,
|
|
53
52
|
) -> None:
|
|
54
53
|
nn.Module.__init__(self)
|
|
55
|
-
self.det_predictor = det_predictor.eval()
|
|
56
|
-
self.reco_predictor = reco_predictor.eval()
|
|
54
|
+
self.det_predictor = det_predictor.eval()
|
|
55
|
+
self.reco_predictor = reco_predictor.eval()
|
|
57
56
|
_OCRPredictor.__init__(
|
|
58
57
|
self,
|
|
59
58
|
assume_straight_pages,
|
|
@@ -69,7 +68,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
69
68
|
@torch.inference_mode()
|
|
70
69
|
def forward(
|
|
71
70
|
self,
|
|
72
|
-
pages:
|
|
71
|
+
pages: list[np.ndarray | torch.Tensor],
|
|
73
72
|
**kwargs: Any,
|
|
74
73
|
) -> Document:
|
|
75
74
|
# Dimension check
|
|
@@ -87,7 +86,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
87
86
|
for out_map in out_maps
|
|
88
87
|
]
|
|
89
88
|
if self.detect_orientation:
|
|
90
|
-
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
89
|
+
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
91
90
|
orientations = [
|
|
92
91
|
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
|
|
93
92
|
]
|
|
@@ -96,13 +95,16 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
96
95
|
general_pages_orientations = None
|
|
97
96
|
origin_pages_orientations = None
|
|
98
97
|
if self.straighten_pages:
|
|
99
|
-
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
98
|
+
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
99
|
+
# update page shapes after straightening
|
|
100
|
+
origin_page_shapes = [page.shape[:2] for page in pages]
|
|
101
|
+
|
|
100
102
|
# Forward again to get predictions on straight pages
|
|
101
103
|
loc_preds = self.det_predictor(pages, **kwargs)
|
|
102
104
|
|
|
103
|
-
assert all(
|
|
104
|
-
|
|
105
|
-
)
|
|
105
|
+
assert all(len(loc_pred) == 1 for loc_pred in loc_preds), (
|
|
106
|
+
"Detection Model in ocr_predictor should output only one class"
|
|
107
|
+
)
|
|
106
108
|
|
|
107
109
|
loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
|
|
108
110
|
# Detach objectness scores from loc_preds
|
|
@@ -116,10 +118,11 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
116
118
|
|
|
117
119
|
# Crop images
|
|
118
120
|
crops, loc_preds = self._prepare_crops(
|
|
119
|
-
pages,
|
|
121
|
+
pages,
|
|
120
122
|
loc_preds,
|
|
121
123
|
channels_last=channels_last,
|
|
122
124
|
assume_straight_pages=self.assume_straight_pages,
|
|
125
|
+
assume_horizontal=self._page_orientation_disabled,
|
|
123
126
|
)
|
|
124
127
|
# Rectify crop orientation and get crop orientation predictions
|
|
125
128
|
crop_orientations: Any = []
|
|
@@ -143,7 +146,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
143
146
|
languages_dict = None
|
|
144
147
|
|
|
145
148
|
out = self.doc_builder(
|
|
146
|
-
pages,
|
|
149
|
+
pages,
|
|
147
150
|
boxes,
|
|
148
151
|
objectness_scores,
|
|
149
152
|
text_preds,
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
@@ -24,7 +24,6 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
24
24
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
25
25
|
|
|
26
26
|
Args:
|
|
27
|
-
----
|
|
28
27
|
det_predictor: detection module
|
|
29
28
|
reco_predictor: recognition module
|
|
30
29
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
@@ -69,7 +68,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
69
68
|
|
|
70
69
|
def __call__(
|
|
71
70
|
self,
|
|
72
|
-
pages:
|
|
71
|
+
pages: list[np.ndarray | tf.Tensor],
|
|
73
72
|
**kwargs: Any,
|
|
74
73
|
) -> Document:
|
|
75
74
|
# Dimension check
|
|
@@ -97,13 +96,16 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
97
96
|
origin_pages_orientations = None
|
|
98
97
|
if self.straighten_pages:
|
|
99
98
|
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
99
|
+
# update page shapes after straightening
|
|
100
|
+
origin_page_shapes = [page.shape[:2] for page in pages]
|
|
101
|
+
|
|
100
102
|
# forward again to get predictions on straight pages
|
|
101
|
-
loc_preds_dict = self.det_predictor(pages, **kwargs)
|
|
103
|
+
loc_preds_dict = self.det_predictor(pages, **kwargs)
|
|
102
104
|
|
|
103
|
-
assert all(
|
|
104
|
-
|
|
105
|
-
)
|
|
106
|
-
loc_preds:
|
|
105
|
+
assert all(len(loc_pred) == 1 for loc_pred in loc_preds_dict), (
|
|
106
|
+
"Detection Model in ocr_predictor should output only one class"
|
|
107
|
+
)
|
|
108
|
+
loc_preds: list[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict]
|
|
107
109
|
# Detach objectness scores from loc_preds
|
|
108
110
|
loc_preds, objectness_scores = detach_scores(loc_preds)
|
|
109
111
|
|
|
@@ -113,7 +115,11 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
113
115
|
|
|
114
116
|
# Crop images
|
|
115
117
|
crops, loc_preds = self._prepare_crops(
|
|
116
|
-
pages,
|
|
118
|
+
pages,
|
|
119
|
+
loc_preds,
|
|
120
|
+
channels_last=True,
|
|
121
|
+
assume_straight_pages=self.assume_straight_pages,
|
|
122
|
+
assume_horizontal=self._page_orientation_disabled,
|
|
117
123
|
)
|
|
118
124
|
# Rectify crop orientation and get crop orientation predictions
|
|
119
125
|
crop_orientations: Any = []
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
elif
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import torch
|
|
@@ -22,19 +22,19 @@ class PreProcessor(nn.Module):
|
|
|
22
22
|
"""Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
|
|
23
23
|
|
|
24
24
|
Args:
|
|
25
|
-
----
|
|
26
25
|
output_size: expected size of each page in format (H, W)
|
|
27
26
|
batch_size: the size of page batches
|
|
28
27
|
mean: mean value of the training distribution by channel
|
|
29
28
|
std: standard deviation of the training distribution by channel
|
|
29
|
+
**kwargs: additional arguments for the resizing operation
|
|
30
30
|
"""
|
|
31
31
|
|
|
32
32
|
def __init__(
|
|
33
33
|
self,
|
|
34
|
-
output_size:
|
|
34
|
+
output_size: tuple[int, int],
|
|
35
35
|
batch_size: int,
|
|
36
|
-
mean:
|
|
37
|
-
std:
|
|
36
|
+
mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
|
|
37
|
+
std: tuple[float, float, float] = (1.0, 1.0, 1.0),
|
|
38
38
|
**kwargs: Any,
|
|
39
39
|
) -> None:
|
|
40
40
|
super().__init__()
|
|
@@ -43,15 +43,13 @@ class PreProcessor(nn.Module):
|
|
|
43
43
|
# Perform the division by 255 at the same time
|
|
44
44
|
self.normalize = T.Normalize(mean, std)
|
|
45
45
|
|
|
46
|
-
def batch_inputs(self, samples:
|
|
46
|
+
def batch_inputs(self, samples: list[torch.Tensor]) -> list[torch.Tensor]:
|
|
47
47
|
"""Gather samples into batches for inference purposes
|
|
48
48
|
|
|
49
49
|
Args:
|
|
50
|
-
----
|
|
51
50
|
samples: list of samples of shape (C, H, W)
|
|
52
51
|
|
|
53
52
|
Returns:
|
|
54
|
-
-------
|
|
55
53
|
list of batched samples (*, C, H, W)
|
|
56
54
|
"""
|
|
57
55
|
num_batches = int(math.ceil(len(samples) / self.batch_size))
|
|
@@ -62,7 +60,7 @@ class PreProcessor(nn.Module):
|
|
|
62
60
|
|
|
63
61
|
return batches
|
|
64
62
|
|
|
65
|
-
def sample_transforms(self, x:
|
|
63
|
+
def sample_transforms(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
|
|
66
64
|
if x.ndim != 3:
|
|
67
65
|
raise AssertionError("expected list of 3D Tensors")
|
|
68
66
|
if isinstance(x, np.ndarray):
|
|
@@ -79,17 +77,15 @@ class PreProcessor(nn.Module):
|
|
|
79
77
|
else:
|
|
80
78
|
x = x.to(dtype=torch.float32) # type: ignore[union-attr]
|
|
81
79
|
|
|
82
|
-
return x
|
|
80
|
+
return x # type: ignore[return-value]
|
|
83
81
|
|
|
84
|
-
def __call__(self, x:
|
|
82
|
+
def __call__(self, x: torch.Tensor | np.ndarray | list[torch.Tensor | np.ndarray]) -> list[torch.Tensor]:
|
|
85
83
|
"""Prepare document data for model forwarding
|
|
86
84
|
|
|
87
85
|
Args:
|
|
88
|
-
----
|
|
89
86
|
x: list of images (np.array) or tensors (already resized and batched)
|
|
90
87
|
|
|
91
88
|
Returns:
|
|
92
|
-
-------
|
|
93
89
|
list of page batches
|
|
94
90
|
"""
|
|
95
91
|
# Input type check
|
|
@@ -103,7 +99,7 @@ class PreProcessor(nn.Module):
|
|
|
103
99
|
elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
|
|
104
100
|
raise TypeError("unsupported data type for torch.Tensor")
|
|
105
101
|
# Resizing
|
|
106
|
-
if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
|
|
102
|
+
if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]: # type: ignore[union-attr]
|
|
107
103
|
x = F.resize(
|
|
108
104
|
x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
|
|
109
105
|
)
|
|
@@ -118,11 +114,11 @@ class PreProcessor(nn.Module):
|
|
|
118
114
|
# Sample transform (to tensor, resize)
|
|
119
115
|
samples = list(multithread_exec(self.sample_transforms, x))
|
|
120
116
|
# Batching
|
|
121
|
-
batches = self.batch_inputs(samples)
|
|
117
|
+
batches = self.batch_inputs(samples) # type: ignore[assignment]
|
|
122
118
|
else:
|
|
123
119
|
raise TypeError(f"invalid input type: {type(x)}")
|
|
124
120
|
|
|
125
121
|
# Batch transforms (normalize)
|
|
126
122
|
batches = list(multithread_exec(self.normalize, batches))
|
|
127
123
|
|
|
128
|
-
return batches
|
|
124
|
+
return batches # type: ignore[return-value]
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import tensorflow as tf
|
|
@@ -20,38 +20,36 @@ class PreProcessor(NestedObject):
|
|
|
20
20
|
"""Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
output_size: expected size of each page in format (H, W)
|
|
25
24
|
batch_size: the size of page batches
|
|
26
25
|
mean: mean value of the training distribution by channel
|
|
27
26
|
std: standard deviation of the training distribution by channel
|
|
27
|
+
**kwargs: additional arguments for the resizing operation
|
|
28
28
|
"""
|
|
29
29
|
|
|
30
|
-
_children_names:
|
|
30
|
+
_children_names: list[str] = ["resize", "normalize"]
|
|
31
31
|
|
|
32
32
|
def __init__(
|
|
33
33
|
self,
|
|
34
|
-
output_size:
|
|
34
|
+
output_size: tuple[int, int],
|
|
35
35
|
batch_size: int,
|
|
36
|
-
mean:
|
|
37
|
-
std:
|
|
36
|
+
mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
|
|
37
|
+
std: tuple[float, float, float] = (1.0, 1.0, 1.0),
|
|
38
38
|
**kwargs: Any,
|
|
39
39
|
) -> None:
|
|
40
40
|
self.batch_size = batch_size
|
|
41
41
|
self.resize = Resize(output_size, **kwargs)
|
|
42
42
|
# Perform the division by 255 at the same time
|
|
43
43
|
self.normalize = Normalize(mean, std)
|
|
44
|
-
self._runs_on_cuda = tf.
|
|
44
|
+
self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []
|
|
45
45
|
|
|
46
|
-
def batch_inputs(self, samples:
|
|
46
|
+
def batch_inputs(self, samples: list[tf.Tensor]) -> list[tf.Tensor]:
|
|
47
47
|
"""Gather samples into batches for inference purposes
|
|
48
48
|
|
|
49
49
|
Args:
|
|
50
|
-
----
|
|
51
50
|
samples: list of samples (tf.Tensor)
|
|
52
51
|
|
|
53
52
|
Returns:
|
|
54
|
-
-------
|
|
55
53
|
list of batched samples
|
|
56
54
|
"""
|
|
57
55
|
num_batches = int(math.ceil(len(samples) / self.batch_size))
|
|
@@ -62,7 +60,7 @@ class PreProcessor(NestedObject):
|
|
|
62
60
|
|
|
63
61
|
return batches
|
|
64
62
|
|
|
65
|
-
def sample_transforms(self, x:
|
|
63
|
+
def sample_transforms(self, x: np.ndarray | tf.Tensor) -> tf.Tensor:
|
|
66
64
|
if x.ndim != 3:
|
|
67
65
|
raise AssertionError("expected list of 3D Tensors")
|
|
68
66
|
if isinstance(x, np.ndarray):
|
|
@@ -79,15 +77,13 @@ class PreProcessor(NestedObject):
|
|
|
79
77
|
|
|
80
78
|
return x
|
|
81
79
|
|
|
82
|
-
def __call__(self, x:
|
|
80
|
+
def __call__(self, x: tf.Tensor | np.ndarray | list[tf.Tensor | np.ndarray]) -> list[tf.Tensor]:
|
|
83
81
|
"""Prepare document data for model forwarding
|
|
84
82
|
|
|
85
83
|
Args:
|
|
86
|
-
----
|
|
87
84
|
x: list of images (np.array) or tensors (already resized and batched)
|
|
88
85
|
|
|
89
86
|
Returns:
|
|
90
|
-
-------
|
|
91
87
|
list of page batches
|
|
92
88
|
"""
|
|
93
89
|
# Input type check
|
doctr/models/recognition/core.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List, Tuple
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
|
|
@@ -21,17 +20,15 @@ class RecognitionModel(NestedObject):
|
|
|
21
20
|
|
|
22
21
|
def build_target(
|
|
23
22
|
self,
|
|
24
|
-
gts:
|
|
25
|
-
) ->
|
|
23
|
+
gts: list[str],
|
|
24
|
+
) -> tuple[np.ndarray, list[int]]:
|
|
26
25
|
"""Encode a list of gts sequences into a np array and gives the corresponding*
|
|
27
26
|
sequence lengths.
|
|
28
27
|
|
|
29
28
|
Args:
|
|
30
|
-
----
|
|
31
29
|
gts: list of ground-truth labels
|
|
32
30
|
|
|
33
31
|
Returns:
|
|
34
|
-
-------
|
|
35
32
|
A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
|
|
36
33
|
"""
|
|
37
34
|
encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab))
|
|
@@ -43,7 +40,6 @@ class RecognitionPostProcessor(NestedObject):
|
|
|
43
40
|
"""Abstract class to postprocess the raw output of the model
|
|
44
41
|
|
|
45
42
|
Args:
|
|
46
|
-
----
|
|
47
43
|
vocab: string containing the ordered sequence of supported characters
|
|
48
44
|
"""
|
|
49
45
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
elif
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|