python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/predictor/base.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
|
|
@@ -21,7 +22,6 @@ class _OCRPredictor:
|
|
|
21
22
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
22
23
|
|
|
23
24
|
Args:
|
|
24
|
-
----
|
|
25
25
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
26
26
|
without rotated textual elements.
|
|
27
27
|
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
|
|
@@ -34,8 +34,8 @@ class _OCRPredictor:
|
|
|
34
34
|
**kwargs: keyword args of `DocumentBuilder`
|
|
35
35
|
"""
|
|
36
36
|
|
|
37
|
-
crop_orientation_predictor:
|
|
38
|
-
page_orientation_predictor:
|
|
37
|
+
crop_orientation_predictor: OrientationPredictor | None
|
|
38
|
+
page_orientation_predictor: OrientationPredictor | None
|
|
39
39
|
|
|
40
40
|
def __init__(
|
|
41
41
|
self,
|
|
@@ -63,12 +63,12 @@ class _OCRPredictor:
|
|
|
63
63
|
self.doc_builder = DocumentBuilder(**kwargs)
|
|
64
64
|
self.preserve_aspect_ratio = preserve_aspect_ratio
|
|
65
65
|
self.symmetric_pad = symmetric_pad
|
|
66
|
-
self.hooks:
|
|
66
|
+
self.hooks: list[Callable] = []
|
|
67
67
|
|
|
68
68
|
def _general_page_orientations(
|
|
69
69
|
self,
|
|
70
|
-
pages:
|
|
71
|
-
) ->
|
|
70
|
+
pages: list[np.ndarray],
|
|
71
|
+
) -> list[tuple[int, float]]:
|
|
72
72
|
_, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
|
|
73
73
|
# Flatten to list of tuples with (value, confidence)
|
|
74
74
|
page_orientations = [
|
|
@@ -79,8 +79,8 @@ class _OCRPredictor:
|
|
|
79
79
|
return page_orientations
|
|
80
80
|
|
|
81
81
|
def _get_orientations(
|
|
82
|
-
self, pages:
|
|
83
|
-
) ->
|
|
82
|
+
self, pages: list[np.ndarray], seg_maps: list[np.ndarray]
|
|
83
|
+
) -> tuple[list[tuple[int, float]], list[int]]:
|
|
84
84
|
general_pages_orientations = self._general_page_orientations(pages)
|
|
85
85
|
origin_page_orientations = [
|
|
86
86
|
estimate_orientation(seq_map, general_orientation)
|
|
@@ -90,11 +90,11 @@ class _OCRPredictor:
|
|
|
90
90
|
|
|
91
91
|
def _straighten_pages(
|
|
92
92
|
self,
|
|
93
|
-
pages:
|
|
94
|
-
seg_maps:
|
|
95
|
-
general_pages_orientations:
|
|
96
|
-
origin_pages_orientations:
|
|
97
|
-
) ->
|
|
93
|
+
pages: list[np.ndarray],
|
|
94
|
+
seg_maps: list[np.ndarray],
|
|
95
|
+
general_pages_orientations: list[tuple[int, float]] | None = None,
|
|
96
|
+
origin_pages_orientations: list[int] | None = None,
|
|
97
|
+
) -> list[np.ndarray]:
|
|
98
98
|
general_pages_orientations = (
|
|
99
99
|
general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
|
|
100
100
|
)
|
|
@@ -114,12 +114,12 @@ class _OCRPredictor:
|
|
|
114
114
|
|
|
115
115
|
@staticmethod
|
|
116
116
|
def _generate_crops(
|
|
117
|
-
pages:
|
|
118
|
-
loc_preds:
|
|
117
|
+
pages: list[np.ndarray],
|
|
118
|
+
loc_preds: list[np.ndarray],
|
|
119
119
|
channels_last: bool,
|
|
120
120
|
assume_straight_pages: bool = False,
|
|
121
121
|
assume_horizontal: bool = False,
|
|
122
|
-
) ->
|
|
122
|
+
) -> list[list[np.ndarray]]:
|
|
123
123
|
if assume_straight_pages:
|
|
124
124
|
crops = [
|
|
125
125
|
extract_crops(page, _boxes[:, :4], channels_last=channels_last)
|
|
@@ -134,12 +134,12 @@ class _OCRPredictor:
|
|
|
134
134
|
|
|
135
135
|
@staticmethod
|
|
136
136
|
def _prepare_crops(
|
|
137
|
-
pages:
|
|
138
|
-
loc_preds:
|
|
137
|
+
pages: list[np.ndarray],
|
|
138
|
+
loc_preds: list[np.ndarray],
|
|
139
139
|
channels_last: bool,
|
|
140
140
|
assume_straight_pages: bool = False,
|
|
141
141
|
assume_horizontal: bool = False,
|
|
142
|
-
) ->
|
|
142
|
+
) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
|
|
143
143
|
crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
|
|
144
144
|
|
|
145
145
|
# Avoid sending zero-sized crops
|
|
@@ -154,9 +154,9 @@ class _OCRPredictor:
|
|
|
154
154
|
|
|
155
155
|
def _rectify_crops(
|
|
156
156
|
self,
|
|
157
|
-
crops:
|
|
158
|
-
loc_preds:
|
|
159
|
-
) ->
|
|
157
|
+
crops: list[list[np.ndarray]],
|
|
158
|
+
loc_preds: list[np.ndarray],
|
|
159
|
+
) -> tuple[list[list[np.ndarray]], list[np.ndarray], list[tuple[int, float]]]:
|
|
160
160
|
# Work at a page level
|
|
161
161
|
orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc]
|
|
162
162
|
rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
|
|
@@ -174,10 +174,10 @@ class _OCRPredictor:
|
|
|
174
174
|
|
|
175
175
|
@staticmethod
|
|
176
176
|
def _process_predictions(
|
|
177
|
-
loc_preds:
|
|
178
|
-
word_preds:
|
|
179
|
-
crop_orientations:
|
|
180
|
-
) ->
|
|
177
|
+
loc_preds: list[np.ndarray],
|
|
178
|
+
word_preds: list[tuple[str, float]],
|
|
179
|
+
crop_orientations: list[dict[str, Any]],
|
|
180
|
+
) -> tuple[list[np.ndarray], list[list[tuple[str, float]]], list[list[dict[str, Any]]]]:
|
|
181
181
|
text_preds = []
|
|
182
182
|
crop_orientation_preds = []
|
|
183
183
|
if len(loc_preds) > 0:
|
|
@@ -194,7 +194,6 @@ class _OCRPredictor:
|
|
|
194
194
|
"""Add a hook to the predictor
|
|
195
195
|
|
|
196
196
|
Args:
|
|
197
|
-
----
|
|
198
197
|
hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
|
|
199
198
|
"""
|
|
200
199
|
self.hooks.append(hook)
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -24,7 +24,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
24
24
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
25
25
|
|
|
26
26
|
Args:
|
|
27
|
-
----
|
|
28
27
|
det_predictor: detection module
|
|
29
28
|
reco_predictor: recognition module
|
|
30
29
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
@@ -52,8 +51,8 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
52
51
|
**kwargs: Any,
|
|
53
52
|
) -> None:
|
|
54
53
|
nn.Module.__init__(self)
|
|
55
|
-
self.det_predictor = det_predictor.eval()
|
|
56
|
-
self.reco_predictor = reco_predictor.eval()
|
|
54
|
+
self.det_predictor = det_predictor.eval()
|
|
55
|
+
self.reco_predictor = reco_predictor.eval()
|
|
57
56
|
_OCRPredictor.__init__(
|
|
58
57
|
self,
|
|
59
58
|
assume_straight_pages,
|
|
@@ -69,7 +68,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
69
68
|
@torch.inference_mode()
|
|
70
69
|
def forward(
|
|
71
70
|
self,
|
|
72
|
-
pages:
|
|
71
|
+
pages: list[np.ndarray | torch.Tensor],
|
|
73
72
|
**kwargs: Any,
|
|
74
73
|
) -> Document:
|
|
75
74
|
# Dimension check
|
|
@@ -87,7 +86,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
87
86
|
for out_map in out_maps
|
|
88
87
|
]
|
|
89
88
|
if self.detect_orientation:
|
|
90
|
-
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
89
|
+
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
|
|
91
90
|
orientations = [
|
|
92
91
|
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
|
|
93
92
|
]
|
|
@@ -96,16 +95,16 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
96
95
|
general_pages_orientations = None
|
|
97
96
|
origin_pages_orientations = None
|
|
98
97
|
if self.straighten_pages:
|
|
99
|
-
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
98
|
+
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
|
|
100
99
|
# update page shapes after straightening
|
|
101
100
|
origin_page_shapes = [page.shape[:2] for page in pages]
|
|
102
101
|
|
|
103
102
|
# Forward again to get predictions on straight pages
|
|
104
103
|
loc_preds = self.det_predictor(pages, **kwargs)
|
|
105
104
|
|
|
106
|
-
assert all(
|
|
107
|
-
|
|
108
|
-
)
|
|
105
|
+
assert all(len(loc_pred) == 1 for loc_pred in loc_preds), (
|
|
106
|
+
"Detection Model in ocr_predictor should output only one class"
|
|
107
|
+
)
|
|
109
108
|
|
|
110
109
|
loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
|
|
111
110
|
# Detach objectness scores from loc_preds
|
|
@@ -119,7 +118,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
119
118
|
|
|
120
119
|
# Crop images
|
|
121
120
|
crops, loc_preds = self._prepare_crops(
|
|
122
|
-
pages,
|
|
121
|
+
pages,
|
|
123
122
|
loc_preds,
|
|
124
123
|
channels_last=channels_last,
|
|
125
124
|
assume_straight_pages=self.assume_straight_pages,
|
|
@@ -147,11 +146,11 @@ class OCRPredictor(nn.Module, _OCRPredictor):
|
|
|
147
146
|
languages_dict = None
|
|
148
147
|
|
|
149
148
|
out = self.doc_builder(
|
|
150
|
-
pages,
|
|
149
|
+
pages,
|
|
151
150
|
boxes,
|
|
152
151
|
objectness_scores,
|
|
153
152
|
text_preds,
|
|
154
|
-
origin_page_shapes,
|
|
153
|
+
origin_page_shapes,
|
|
155
154
|
crop_orientations,
|
|
156
155
|
orientations,
|
|
157
156
|
languages_dict,
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
@@ -24,7 +24,6 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
24
24
|
"""Implements an object able to localize and identify text elements in a set of documents
|
|
25
25
|
|
|
26
26
|
Args:
|
|
27
|
-
----
|
|
28
27
|
det_predictor: detection module
|
|
29
28
|
reco_predictor: recognition module
|
|
30
29
|
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
|
|
@@ -69,7 +68,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
69
68
|
|
|
70
69
|
def __call__(
|
|
71
70
|
self,
|
|
72
|
-
pages:
|
|
71
|
+
pages: list[np.ndarray | tf.Tensor],
|
|
73
72
|
**kwargs: Any,
|
|
74
73
|
) -> Document:
|
|
75
74
|
# Dimension check
|
|
@@ -101,12 +100,12 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
101
100
|
origin_page_shapes = [page.shape[:2] for page in pages]
|
|
102
101
|
|
|
103
102
|
# forward again to get predictions on straight pages
|
|
104
|
-
loc_preds_dict = self.det_predictor(pages, **kwargs)
|
|
103
|
+
loc_preds_dict = self.det_predictor(pages, **kwargs)
|
|
105
104
|
|
|
106
|
-
assert all(
|
|
107
|
-
|
|
108
|
-
)
|
|
109
|
-
loc_preds:
|
|
105
|
+
assert all(len(loc_pred) == 1 for loc_pred in loc_preds_dict), (
|
|
106
|
+
"Detection Model in ocr_predictor should output only one class"
|
|
107
|
+
)
|
|
108
|
+
loc_preds: list[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict]
|
|
110
109
|
# Detach objectness scores from loc_preds
|
|
111
110
|
loc_preds, objectness_scores = detach_scores(loc_preds)
|
|
112
111
|
|
|
@@ -148,7 +147,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
|
|
|
148
147
|
boxes,
|
|
149
148
|
objectness_scores,
|
|
150
149
|
text_preds,
|
|
151
|
-
origin_page_shapes,
|
|
150
|
+
origin_page_shapes,
|
|
152
151
|
crop_orientations,
|
|
153
152
|
orientations,
|
|
154
153
|
languages_dict,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
elif
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import torch
|
|
@@ -22,19 +22,19 @@ class PreProcessor(nn.Module):
|
|
|
22
22
|
"""Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
|
|
23
23
|
|
|
24
24
|
Args:
|
|
25
|
-
----
|
|
26
25
|
output_size: expected size of each page in format (H, W)
|
|
27
26
|
batch_size: the size of page batches
|
|
28
27
|
mean: mean value of the training distribution by channel
|
|
29
28
|
std: standard deviation of the training distribution by channel
|
|
29
|
+
**kwargs: additional arguments for the resizing operation
|
|
30
30
|
"""
|
|
31
31
|
|
|
32
32
|
def __init__(
|
|
33
33
|
self,
|
|
34
|
-
output_size:
|
|
34
|
+
output_size: tuple[int, int],
|
|
35
35
|
batch_size: int,
|
|
36
|
-
mean:
|
|
37
|
-
std:
|
|
36
|
+
mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
|
|
37
|
+
std: tuple[float, float, float] = (1.0, 1.0, 1.0),
|
|
38
38
|
**kwargs: Any,
|
|
39
39
|
) -> None:
|
|
40
40
|
super().__init__()
|
|
@@ -43,15 +43,13 @@ class PreProcessor(nn.Module):
|
|
|
43
43
|
# Perform the division by 255 at the same time
|
|
44
44
|
self.normalize = T.Normalize(mean, std)
|
|
45
45
|
|
|
46
|
-
def batch_inputs(self, samples:
|
|
46
|
+
def batch_inputs(self, samples: list[torch.Tensor]) -> list[torch.Tensor]:
|
|
47
47
|
"""Gather samples into batches for inference purposes
|
|
48
48
|
|
|
49
49
|
Args:
|
|
50
|
-
----
|
|
51
50
|
samples: list of samples of shape (C, H, W)
|
|
52
51
|
|
|
53
52
|
Returns:
|
|
54
|
-
-------
|
|
55
53
|
list of batched samples (*, C, H, W)
|
|
56
54
|
"""
|
|
57
55
|
num_batches = int(math.ceil(len(samples) / self.batch_size))
|
|
@@ -62,7 +60,7 @@ class PreProcessor(nn.Module):
|
|
|
62
60
|
|
|
63
61
|
return batches
|
|
64
62
|
|
|
65
|
-
def sample_transforms(self, x:
|
|
63
|
+
def sample_transforms(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
|
|
66
64
|
if x.ndim != 3:
|
|
67
65
|
raise AssertionError("expected list of 3D Tensors")
|
|
68
66
|
if isinstance(x, np.ndarray):
|
|
@@ -79,17 +77,15 @@ class PreProcessor(nn.Module):
|
|
|
79
77
|
else:
|
|
80
78
|
x = x.to(dtype=torch.float32) # type: ignore[union-attr]
|
|
81
79
|
|
|
82
|
-
return x
|
|
80
|
+
return x # type: ignore[return-value]
|
|
83
81
|
|
|
84
|
-
def __call__(self, x:
|
|
82
|
+
def __call__(self, x: torch.Tensor | np.ndarray | list[torch.Tensor | np.ndarray]) -> list[torch.Tensor]:
|
|
85
83
|
"""Prepare document data for model forwarding
|
|
86
84
|
|
|
87
85
|
Args:
|
|
88
|
-
----
|
|
89
86
|
x: list of images (np.array) or tensors (already resized and batched)
|
|
90
87
|
|
|
91
88
|
Returns:
|
|
92
|
-
-------
|
|
93
89
|
list of page batches
|
|
94
90
|
"""
|
|
95
91
|
# Input type check
|
|
@@ -103,7 +99,7 @@ class PreProcessor(nn.Module):
|
|
|
103
99
|
elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
|
|
104
100
|
raise TypeError("unsupported data type for torch.Tensor")
|
|
105
101
|
# Resizing
|
|
106
|
-
if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
|
|
102
|
+
if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]: # type: ignore[union-attr]
|
|
107
103
|
x = F.resize(
|
|
108
104
|
x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
|
|
109
105
|
)
|
|
@@ -118,11 +114,11 @@ class PreProcessor(nn.Module):
|
|
|
118
114
|
# Sample transform (to tensor, resize)
|
|
119
115
|
samples = list(multithread_exec(self.sample_transforms, x))
|
|
120
116
|
# Batching
|
|
121
|
-
batches = self.batch_inputs(samples)
|
|
117
|
+
batches = self.batch_inputs(samples) # type: ignore[assignment]
|
|
122
118
|
else:
|
|
123
119
|
raise TypeError(f"invalid input type: {type(x)}")
|
|
124
120
|
|
|
125
121
|
# Batch transforms (normalize)
|
|
126
122
|
batches = list(multithread_exec(self.normalize, batches))
|
|
127
123
|
|
|
128
|
-
return batches
|
|
124
|
+
return batches # type: ignore[return-value]
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import tensorflow as tf
|
|
@@ -20,21 +20,21 @@ class PreProcessor(NestedObject):
|
|
|
20
20
|
"""Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
output_size: expected size of each page in format (H, W)
|
|
25
24
|
batch_size: the size of page batches
|
|
26
25
|
mean: mean value of the training distribution by channel
|
|
27
26
|
std: standard deviation of the training distribution by channel
|
|
27
|
+
**kwargs: additional arguments for the resizing operation
|
|
28
28
|
"""
|
|
29
29
|
|
|
30
|
-
_children_names:
|
|
30
|
+
_children_names: list[str] = ["resize", "normalize"]
|
|
31
31
|
|
|
32
32
|
def __init__(
|
|
33
33
|
self,
|
|
34
|
-
output_size:
|
|
34
|
+
output_size: tuple[int, int],
|
|
35
35
|
batch_size: int,
|
|
36
|
-
mean:
|
|
37
|
-
std:
|
|
36
|
+
mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
|
|
37
|
+
std: tuple[float, float, float] = (1.0, 1.0, 1.0),
|
|
38
38
|
**kwargs: Any,
|
|
39
39
|
) -> None:
|
|
40
40
|
self.batch_size = batch_size
|
|
@@ -43,15 +43,13 @@ class PreProcessor(NestedObject):
|
|
|
43
43
|
self.normalize = Normalize(mean, std)
|
|
44
44
|
self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []
|
|
45
45
|
|
|
46
|
-
def batch_inputs(self, samples:
|
|
46
|
+
def batch_inputs(self, samples: list[tf.Tensor]) -> list[tf.Tensor]:
|
|
47
47
|
"""Gather samples into batches for inference purposes
|
|
48
48
|
|
|
49
49
|
Args:
|
|
50
|
-
----
|
|
51
50
|
samples: list of samples (tf.Tensor)
|
|
52
51
|
|
|
53
52
|
Returns:
|
|
54
|
-
-------
|
|
55
53
|
list of batched samples
|
|
56
54
|
"""
|
|
57
55
|
num_batches = int(math.ceil(len(samples) / self.batch_size))
|
|
@@ -62,7 +60,7 @@ class PreProcessor(NestedObject):
|
|
|
62
60
|
|
|
63
61
|
return batches
|
|
64
62
|
|
|
65
|
-
def sample_transforms(self, x:
|
|
63
|
+
def sample_transforms(self, x: np.ndarray | tf.Tensor) -> tf.Tensor:
|
|
66
64
|
if x.ndim != 3:
|
|
67
65
|
raise AssertionError("expected list of 3D Tensors")
|
|
68
66
|
if isinstance(x, np.ndarray):
|
|
@@ -79,15 +77,13 @@ class PreProcessor(NestedObject):
|
|
|
79
77
|
|
|
80
78
|
return x
|
|
81
79
|
|
|
82
|
-
def __call__(self, x:
|
|
80
|
+
def __call__(self, x: tf.Tensor | np.ndarray | list[tf.Tensor | np.ndarray]) -> list[tf.Tensor]:
|
|
83
81
|
"""Prepare document data for model forwarding
|
|
84
82
|
|
|
85
83
|
Args:
|
|
86
|
-
----
|
|
87
84
|
x: list of images (np.array) or tensors (already resized and batched)
|
|
88
85
|
|
|
89
86
|
Returns:
|
|
90
|
-
-------
|
|
91
87
|
list of page batches
|
|
92
88
|
"""
|
|
93
89
|
# Input type check
|
doctr/models/recognition/core.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import List, Tuple
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
|
|
@@ -21,17 +20,15 @@ class RecognitionModel(NestedObject):
|
|
|
21
20
|
|
|
22
21
|
def build_target(
|
|
23
22
|
self,
|
|
24
|
-
gts:
|
|
25
|
-
) ->
|
|
23
|
+
gts: list[str],
|
|
24
|
+
) -> tuple[np.ndarray, list[int]]:
|
|
26
25
|
"""Encode a list of gts sequences into a np array and gives the corresponding*
|
|
27
26
|
sequence lengths.
|
|
28
27
|
|
|
29
28
|
Args:
|
|
30
|
-
----
|
|
31
29
|
gts: list of ground-truth labels
|
|
32
30
|
|
|
33
31
|
Returns:
|
|
34
|
-
-------
|
|
35
32
|
A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
|
|
36
33
|
"""
|
|
37
34
|
encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab))
|
|
@@ -43,7 +40,6 @@ class RecognitionPostProcessor(NestedObject):
|
|
|
43
40
|
"""Abstract class to postprocess the raw output of the model
|
|
44
41
|
|
|
45
42
|
Args:
|
|
46
|
-
----
|
|
47
43
|
vocab: string containing the ordered sequence of supported characters
|
|
48
44
|
"""
|
|
49
45
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
elif
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|