python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registry.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
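
Beyond the headline additions visible in the file list (the FAST detection family under doctr/models/detection/fast/, the TextNet backbones, and the WildReceipt and orientation datasets), the excerpts below show the recurring line-level changes: refreshed copyright headers, numpy-style docstring separators, and a reworked orientation pipeline. As a quick way to try one of the new architectures, here is a minimal sketch against the public `ocr_predictor` API; whether a pretrained `fast_base` checkpoint ships with 0.8.1 is an assumption:

```python
import numpy as np

from doctr.models import ocr_predictor

# "fast_base" is one of the FAST detection architectures added in this release
# (see doctr/models/detection/fast/ in the file list above)
predictor = ocr_predictor(det_arch="fast_base", reco_arch="crnn_vgg16_bn", pretrained=True)

dummy_page = (np.random.rand(1024, 768, 3) * 255).astype(np.uint8)  # stand-in for a real page
result = predictor([dummy_page])
print(result.pages[0].dimensions)
```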

doctr/models/modules/transformer/tensorflow.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -42,12 +42,15 @@ class PositionalEncoding(layers.Layer, NestedObject):
         x: tf.Tensor,
         **kwargs: Any,
     ) -> tf.Tensor:
-        """
+        """Forward pass
+
         Args:
+        ----
            x: embeddings (batch, max_len, d_model)
            **kwargs: additional arguments
 
-        Returns
+        Returns
+        -------
            positional embeddings (batch, max_len, d_model)
        """
        if x.dtype == tf.float16:  # amp fix: cast to half
@@ -62,7 +65,6 @@ def scaled_dot_product_attention(
     query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: Optional[tf.Tensor] = None
 ) -> Tuple[tf.Tensor, tf.Tensor]:
     """Scaled Dot-Product Attention"""
-
     scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) / math.sqrt(query.shape[-1])
     if mask is not None:
         # NOTE: to ensure the ONNX compatibility, tf.where works only with bool type condition

doctr/models/modules/vision_transformer/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -28,8 +28,7 @@ class PatchEmbedding(nn.Module):
         self.projection = nn.Conv2d(channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
 
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
-        """
-        100 % borrowed from:
+        """100 % borrowed from:
         https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py
 
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
@@ -38,7 +37,6 @@ class PatchEmbedding(nn.Module):
         Source:
         https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py
         """
-
         num_patches = embeddings.shape[1] - 1
         num_positions = self.positions.shape[1] - 1
         if num_patches == num_positions and height == width:

doctr/models/modules/vision_transformer/tensorflow.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -45,8 +45,7 @@ class PatchEmbedding(layers.Layer, NestedObject):
         )
 
     def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
-        """
-        100 % borrowed from:
+        """100 % borrowed from:
         https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_tf_vit.py
 
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
@@ -55,7 +54,6 @@ class PatchEmbedding(layers.Layer, NestedObject):
         Source:
         https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py
         """
-
         seq_len, dim = embeddings.shape[1:]
         num_patches = seq_len - 1
 

doctr/models/obj_detection/faster_rcnn/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -70,10 +70,12 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained: bool = False, **kwargs: Any) -> FasterRCNN:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our object detection dataset
+        **kwargs: keyword arguments of the FasterRCNN architecture
 
     Returns:
+    -------
         object detection architecture
     """
-
     return _fasterrcnn("fasterrcnn_mobilenet_v3_large_fpn", pretrained, **kwargs)

doctr/models/predictor/base.py
CHANGED

@@ -1,9 +1,9 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
 
 import numpy as np
 
@@ -21,6 +21,7 @@ class _OCRPredictor:
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
+    ----
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
             without rotated textual elements.
         straighten_pages: if True, estimates the page general orientation based on the median line orientation.
@@ -28,7 +29,7 @@ class _OCRPredictor:
             accordingly. Doing so will improve performances for documents with page-uniform rotations.
         preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
         symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically.
-        kwargs: keyword args of `DocumentBuilder`
+        **kwargs: keyword args of `DocumentBuilder`
     """
 
     crop_orientation_predictor: Optional[CropOrientationPredictor]
@@ -47,6 +48,7 @@ class _OCRPredictor:
         self.doc_builder = DocumentBuilder(**kwargs)
         self.preserve_aspect_ratio = preserve_aspect_ratio
         self.symmetric_pad = symmetric_pad
+        self.hooks: List[Callable] = []
 
     @staticmethod
     def _generate_crops(
@@ -148,3 +150,12 @@ class _OCRPredictor:
             _idx += page_boxes.shape[0]
 
         return loc_preds, text_preds
+
+    def add_hook(self, hook: Callable) -> None:
+        """Add a hook to the predictor
+
+        Args:
+        ----
+            hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
+        """
+        self.hooks.append(hook)
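
The `hooks` mechanism introduced above lets callers rewrite the detector's localization output before crops are extracted and recognized. A minimal usage sketch, assuming only the contract stated in the docstring (a callable that takes `loc_preds` and returns a modified `loc_preds`); the box layout and threshold are illustrative assumptions:

```python
from doctr.models import ocr_predictor

def drop_low_confidence(loc_preds):
    # Illustrative hook: keep boxes whose trailing score column is >= 0.5.
    # Assumes straight-page output where each page array is (N, 5) with the
    # objectness score last; adapt to your detector's actual output layout.
    return [page_boxes[page_boxes[:, -1] >= 0.5] for page_boxes in loc_preds]

predictor = ocr_predictor(pretrained=True)
predictor.add_hook(drop_low_confidence)  # applied to loc_preds on every forward pass
```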

doctr/models/predictor/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,7 +13,7 @@ from doctr.io.elements import Document
 from doctr.models._utils import estimate_orientation, get_language
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import rotate_boxes, rotate_image
+from doctr.utils.geometry import rotate_image
 
 from .base import _OCRPredictor
 
@@ -24,6 +24,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
+    ----
         det_predictor: detection module
         reco_predictor: recognition module
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -35,7 +36,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
             page. Doing so will slightly deteriorate the overall latency.
         detect_language: if True, the language prediction will be added to the predictions for each
             page. Doing so will slightly deteriorate the overall latency.
-        kwargs: keyword args of `DocumentBuilder`
+        **kwargs: keyword args of `DocumentBuilder`
     """
 
     def __init__(
@@ -59,7 +60,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         self.detect_orientation = detect_orientation
         self.detect_language = detect_language
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def forward(
         self,
         pages: List[Union[np.ndarray, torch.Tensor]],
@@ -71,11 +72,18 @@ class OCRPredictor(nn.Module, _OCRPredictor):
 
         origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
 
+        # Localize text elements
+        loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
         # Detect document rotation and rotate pages
+        seg_maps = [
+            np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8)
+            for out_map in out_maps
+        ]
         if self.detect_orientation:
-            origin_page_orientations = [estimate_orientation(page) for page in pages]
+            origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
             orientations = [
-                {"value": orientation_page, "confidence":
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
             ]
         else:
             orientations = None
@@ -83,15 +91,12 @@ class OCRPredictor(nn.Module, _OCRPredictor):
             origin_page_orientations = (
                 origin_page_orientations
                 if self.detect_orientation
-                else [estimate_orientation(page) for page in pages]
+                else [estimate_orientation(seq_map) for seq_map in seg_maps]
             )
-            pages = [
-                rotate_image(page, -angle, expand=False)
-                for page, angle in zip(pages, origin_page_orientations)
-            ]
+            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            # Forward again to get predictions on straight pages
+            loc_preds = self.det_predictor(pages, **kwargs)
 
-        # Localize text elements
-        loc_preds = self.det_predictor(pages, **kwargs)
         assert all(
             len(loc_pred) == 1 for loc_pred in loc_preds
         ), "Detection Model in ocr_predictor should output only one class"
@@ -101,11 +106,15 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
 
         # Rectify crops if aspect ratio
-        loc_preds = self._remove_padding(pages, loc_preds)
+        loc_preds = self._remove_padding(pages, loc_preds)
+
+        # Apply hooks to loc_preds if any
+        for hook in self.hooks:
+            loc_preds = hook(loc_preds)
 
         # Crop images
         crops, loc_preds = self._prepare_crops(
-            pages,
+            pages,
             loc_preds,
             channels_last=channels_last,
             assume_straight_pages=self.assume_straight_pages,
@@ -123,24 +132,12 @@ class OCRPredictor(nn.Module, _OCRPredictor):
             languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
         else:
             languages_dict = None
-        # Rotate back pages and boxes while keeping original image size
-        if self.straighten_pages:
-            boxes = [
-                rotate_boxes(
-                    page_boxes,
-                    angle,
-                    orig_shape=page.shape[:2]
-                    if isinstance(page, np.ndarray)
-                    else page.shape[1:],  # type: ignore[arg-type]
-                    target_shape=mask,  # type: ignore[arg-type]
-                )
-                for page_boxes, page, angle, mask in zip(boxes, pages, origin_page_orientations, origin_page_shapes)
-            ]
 
         out = self.doc_builder(
+            pages,
             boxes,
             text_preds,
-            origin_page_shapes,  # type: ignore[arg-type]
+            origin_page_shapes,
             orientations,
             languages_dict,
         )
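
Both backends now estimate page orientation from the detector's raw segmentation output instead of the page pixels: the probability maps returned with `return_maps=True` are binarized with the postprocessor's `bin_thresh` and fed to `estimate_orientation`. A standalone sketch of that binarization step, with `bin_thresh` hard-coded as an assumed value (the real one is read from `det_predictor.model.postprocessor`):

```python
import numpy as np

bin_thresh = 0.3  # assumed value for illustration
out_map = np.random.rand(1024, 768)  # stand-in for one raw probability map

# Same transform as the seg_maps comprehension above: threshold to a uint8 mask
seg_map = np.where(out_map > bin_thresh, 255, 0).astype(np.uint8)
# seg_map is what now gets passed to estimate_orientation(...)
```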

doctr/models/predictor/tensorflow.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,7 +12,7 @@ from doctr.io.elements import Document
 from doctr.models._utils import estimate_orientation, get_language
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import rotate_boxes, rotate_image
+from doctr.utils.geometry import rotate_image
 from doctr.utils.repr import NestedObject
 
 from .base import _OCRPredictor
@@ -24,6 +24,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
+    ----
         det_predictor: detection module
         reco_predictor: recognition module
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -35,7 +36,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
             page. Doing so will slightly deteriorate the overall latency.
         detect_language: if True, the language prediction will be added to the predictions for each
             page. Doing so will slightly deteriorate the overall latency.
-        kwargs: keyword args of `DocumentBuilder`
+        **kwargs: keyword args of `DocumentBuilder`
     """
 
     _children_names = ["det_predictor", "reco_predictor", "doc_builder"]
@@ -71,31 +72,43 @@ class OCRPredictor(NestedObject, _OCRPredictor):
 
         origin_page_shapes = [page.shape[:2] for page in pages]
 
+        # Localize text elements
+        loc_preds_dict, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
         # Detect document rotation and rotate pages
+        seg_maps = [
+            np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8)
+            for out_map in out_maps
+        ]
         if self.detect_orientation:
-            origin_page_orientations = [estimate_orientation(page) for page in pages]
+            origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
             orientations = [
-                {"value": orientation_page, "confidence":
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
             ]
         else:
             orientations = None
         if self.straighten_pages:
             origin_page_orientations = (
-                origin_page_orientations if self.detect_orientation else [estimate_orientation(page) for page in pages]
+                origin_page_orientations
+                if self.detect_orientation
+                else [estimate_orientation(seq_map) for seq_map in seg_maps]
             )
-            pages = [rotate_image(page, -angle, expand=
+            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            # forward again to get predictions on straight pages
+            loc_preds_dict = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
 
-        # Localize text elements
-        loc_preds_dict = self.det_predictor(pages, **kwargs)
         assert all(
             len(loc_pred) == 1 for loc_pred in loc_preds_dict
         ), "Detection Model in ocr_predictor should output only one class"
-
-        loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict]
+        loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict]  # type: ignore[union-attr]
 
         # Rectify crops if aspect ratio
         loc_preds = self._remove_padding(pages, loc_preds)
 
+        # Apply hooks to loc_preds if any
+        for hook in self.hooks:
+            loc_preds = hook(loc_preds)
+
         # Crop images
         crops, loc_preds = self._prepare_crops(
             pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
@@ -114,19 +127,9 @@ class OCRPredictor(NestedObject, _OCRPredictor):
             languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
         else:
             languages_dict = None
-        # Rotate back pages and boxes while keeping original image size
-        if self.straighten_pages:
-            boxes = [
-                rotate_boxes(
-                    page_boxes,
-                    angle,
-                    orig_shape=page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:],
-                    target_shape=mask,  # type: ignore[arg-type]
-                )
-                for page_boxes, page, angle, mask in zip(boxes, pages, origin_page_orientations, origin_page_shapes)
-            ]
 
         out = self.doc_builder(
+            pages,
             boxes,
             text_preds,
             origin_page_shapes,  # type: ignore[arg-type]

doctr/models/preprocessor/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -22,6 +22,7 @@ class PreProcessor(nn.Module):
     """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
 
     Args:
+    ----
         output_size: expected size of each page in format (H, W)
         batch_size: the size of page batches
         mean: mean value of the training distribution by channel
@@ -34,7 +35,6 @@ class PreProcessor(nn.Module):
         batch_size: int,
         mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
         std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
-        fp16: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__()
@@ -47,12 +47,13 @@ class PreProcessor(nn.Module):
         """Gather samples into batches for inference purposes
 
         Args:
+        ----
            samples: list of samples of shape (C, H, W)
 
        Returns:
+        -------
            list of batched samples (*, C, H, W)
        """
-
        num_batches = int(math.ceil(len(samples) / self.batch_size))
        batches = [
            torch.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], dim=0)
@@ -78,17 +79,19 @@ class PreProcessor(nn.Module):
         else:
             x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
 
-        return x
+        return x  # type: ignore[return-value]
 
     def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]:
         """Prepare document data for model forwarding
 
         Args:
+        ----
            x: list of images (np.array) or tensors (already resized and batched)
+
        Returns:
+        -------
            list of page batches
        """
-
        # Input type check
        if isinstance(x, (np.ndarray, torch.Tensor)):
            if x.ndim != 4:
@@ -100,8 +103,10 @@ class PreProcessor(nn.Module):
             elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
                 raise TypeError("unsupported data type for torch.Tensor")
             # Resizing
-            if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
-                x = F.resize(x, self.resize.size, interpolation=self.resize.interpolation)
+            if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:  # type: ignore[union-attr]
+                x = F.resize(
+                    x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
+                )
             # Data type
             if x.dtype == torch.uint8:  # type: ignore[union-attr]
                 x = x.to(dtype=torch.float32).div(255).clip(0, 1)  # type: ignore[union-attr]
@@ -113,11 +118,11 @@ class PreProcessor(nn.Module):
             # Sample transform (to tensor, resize)
             samples = list(multithread_exec(self.sample_transforms, x))
             # Batching
-            batches = self.batch_inputs(samples)
+            batches = self.batch_inputs(samples)  # type: ignore[assignment]
         else:
             raise TypeError(f"invalid input type: {type(x)}")
 
         # Batch transforms (normalize)
         batches = list(multithread_exec(self.normalize, batches))
 
-        return batches
+        return batches  # type: ignore[return-value]
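
The batching rule in `batch_inputs` is untouched by this release (only the docstring and typing changed): samples are stacked into ceil(len(samples) / batch_size) tensors, with a smaller final batch when the count does not divide evenly. A quick sketch of that arithmetic:

```python
import math

import torch

samples = [torch.rand(3, 32, 128) for _ in range(5)]  # 5 fake crops (C, H, W)
batch_size = 2

# Mirrors the list comprehension in batch_inputs above
num_batches = int(math.ceil(len(samples) / batch_size))
batches = [
    torch.stack(samples[idx * batch_size : min((idx + 1) * batch_size, len(samples))], dim=0)
    for idx in range(num_batches)
]
assert [b.shape[0] for b in batches] == [2, 2, 1]  # last batch holds the remainder
```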

doctr/models/preprocessor/tensorflow.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -20,6 +20,7 @@ class PreProcessor(NestedObject):
     """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
 
     Args:
+    ----
         output_size: expected size of each page in format (H, W)
         batch_size: the size of page batches
         mean: mean value of the training distribution by channel
@@ -34,7 +35,6 @@ class PreProcessor(NestedObject):
         batch_size: int,
         mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
         std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
-        fp16: bool = False,
         **kwargs: Any,
     ) -> None:
         self.batch_size = batch_size
@@ -46,12 +46,13 @@ class PreProcessor(NestedObject):
         """Gather samples into batches for inference purposes
 
         Args:
+        ----
            samples: list of samples (tf.Tensor)
 
        Returns:
+        -------
            list of batched samples
        """
-
        num_batches = int(math.ceil(len(samples) / self.batch_size))
        batches = [
            tf.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0)
@@ -81,11 +82,13 @@ class PreProcessor(NestedObject):
         """Prepare document data for model forwarding
 
         Args:
+        ----
            x: list of images (np.array) or tensors (already resized and batched)
+
        Returns:
+        -------
            list of page batches
        """
-
        # Input type check
        if isinstance(x, (np.ndarray, tf.Tensor)):
            if x.ndim != 4:
@@ -102,7 +105,9 @@ class PreProcessor(NestedObject):
             x = tf.image.convert_image_dtype(x, dtype=tf.float32)
             # Resizing
             if (x.shape[1], x.shape[2]) != self.resize.output_size:
-                x = tf.image.resize(x, self.resize.output_size, method=self.resize.method)
+                x = tf.image.resize(
+                    x, self.resize.output_size, method=self.resize.method, antialias=self.resize.antialias
+                )
 
             batches = [x]
 

doctr/models/recognition/core.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -27,9 +27,11 @@ class RecognitionModel(NestedObject):
         sequence lengths.
 
         Args:
+        ----
            gts: list of ground-truth labels
 
        Returns:
+        -------
            A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
        """
        encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab))
@@ -41,6 +43,7 @@ class RecognitionPostProcessor(NestedObject):
     """Abstract class to postprocess the raw output of the model
 
     Args:
+    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
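
`build_target` still delegates to `encode_sequences` to turn ground-truth strings into fixed-size integer targets padded with the EOS index. A toy sketch of that call with made-up vocab and sizes:

```python
from doctr.datasets import encode_sequences

vocab = "abc"  # toy vocab; real models use the entries from doctr/datasets/vocabs.py
encoded = encode_sequences(sequences=["ab", "c"], vocab=vocab, target_size=4, eos=len(vocab))
print(encoded.shape)  # (2, 4); unused positions carry the eos index, 3 here
```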