python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/utils/tensorflow.py
CHANGED
@@ -1,13 +1,15 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-from typing import Any, Callable, List, Optional, Tuple, Union
+from collections.abc import Callable
+from typing import Any
 
 import tensorflow as tf
 import tf2onnx
+import validators
 from tensorflow.keras import Model, layers
 
 from doctr.utils.data import download_from_url
@@ -39,7 +41,6 @@ def _build_model(model: Model):
     """Build a model by calling it once with dummy input
 
     Args:
-    ----
         model: the model to be built
     """
     model(tf.zeros((1, *model.cfg["input_shape"])), training=False)
@@ -47,8 +48,8 @@ def _build_model(model: Model):
 
 def load_pretrained_params(
     model: Model,
-    url: Optional[str] = None,
-    hash_prefix: Optional[str] = None,
+    path_or_url: str | None = None,
+    hash_prefix: str | None = None,
     skip_mismatch: bool = False,
     **kwargs: Any,
 ) -> None:
@@ -58,29 +59,34 @@
     >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5")
 
     Args:
-    ----
         model: the keras model to be loaded
-        url: URL of the zipped set of parameters
+        path_or_url: the path or URL to the model parameters (checkpoint)
         hash_prefix: first characters of SHA256 expected hash
         skip_mismatch: skip loading layers with mismatched shapes
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
-    if url is None:
-        logging.warning("Invalid model URL, using default initialization.")
-    else:
-        archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
-        # Load weights
-        model.load_weights(archive_path, skip_mismatch=skip_mismatch)
+    if path_or_url is None:
+        logging.warning("No model URL or Path provided, using default initialization.")
+        return
+
+    archive_path = (
+        download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
+        if validators.url(path_or_url)
+        else path_or_url
+    )
+
+    # Load weights
+    model.load_weights(archive_path, skip_mismatch=skip_mismatch)
 
 
 def conv_sequence(
     out_channels: int,
-    activation: Optional[Union[str, Callable]] = None,
+    activation: str | Callable | None = None,
     bn: bool = False,
     padding: str = "same",
     kernel_initializer: str = "he_normal",
     **kwargs: Any,
-) -> List[layers.Layer]:
+) -> list[layers.Layer]:
     """Builds a convolutional-based layer sequence
 
     >>> from tensorflow.keras import Sequential
@@ -88,7 +94,6 @@ def conv_sequence(
     >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))
 
     Args:
-    ----
         out_channels: number of output channels
         activation: activation to be used (default: no activation)
         bn: should a batch normalization layer be added
@@ -97,7 +102,6 @@ def conv_sequence(
         **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
-    -------
         list of layers
     """
     # No bias before Batch norm
@@ -125,12 +129,11 @@ class IntermediateLayerGetter(Model):
     >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers)
 
     Args:
-    ----
         model: the model to extract feature maps from
         layer_names: the list of layers to retrieve the feature map from
     """
 
-    def __init__(self, model: Model, layer_names: List[str]) -> None:
+    def __init__(self, model: Model, layer_names: list[str]) -> None:
         intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names]
         super().__init__(model.input, outputs=intermediate_fmaps)
 
@@ -139,8 +142,8 @@ class IntermediateLayerGetter(Model):
 
 
 def export_model_to_onnx(
-    model: Model, model_name: str, dummy_input: List[tf.TensorSpec], **kwargs: Any
-) -> Tuple[str, List[str]]:
+    model: Model, model_name: str, dummy_input: list[tf.TensorSpec], **kwargs: Any
+) -> tuple[str, list[str]]:
     """Export model to ONNX format.
 
     >>> import tensorflow as tf
@@ -151,16 +154,18 @@ def export_model_to_onnx(
     >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")])
 
     Args:
-    ----
         model: the keras model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to tf2onnx
 
     Returns:
-    -------
         the path to the exported model and a list with the output layer names
     """
+    # get the users eager mode
+    eager_mode = tf.executing_eagerly()
+    # set eager mode to true to avoid issues with tf2onnx
+    tf.config.run_functions_eagerly(True)
     large_model = kwargs.get("large_model", False)
     model_proto, _ = tf2onnx.convert.from_keras(
         model,
@@ -171,6 +176,9 @@ def export_model_to_onnx(
     # Get the output layer names
     output = [n.name for n in model_proto.graph.output]
 
+    # reset the eager mode to the users mode
+    tf.config.run_functions_eagerly(eager_mode)
+
     # models which are too large (weights > 2GB while converting to ONNX) needs to be handled
     # about an external tensor storage where the graph and weights are seperatly stored in a archive
     if large_model:
doctr/models/zoo.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -83,7 +83,6 @@ def ocr_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -108,7 +107,6 @@ def ocr_predictor(
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
-    -------
         OCR predictor
     """
     return _predictor(
@@ -197,7 +195,6 @@ def kie_predictor(
     >>> out = model([input_page])
 
    Args:
-    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -222,7 +219,6 @@ def kie_predictor(
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
-    -------
         KIE predictor
     """
     return _kie_predictor(
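The zoo.py changes are cosmetic: the copyright year is bumped and the numpydoc-style `----`/`-------` separators are stripped from the `ocr_predictor` and `kie_predictor` docstrings. The public API is untouched; a typical call, consistent with the docstring examples above (the file path is a placeholder):

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    # Build the default end-to-end predictor and run it on a PDF
    model = ocr_predictor(det_arch="db_resnet50", reco_arch="crnn_vgg16_bn", pretrained=True)
    doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
    out = model(doc)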
doctr/transforms/functional/base.py
CHANGED
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Tuple, Union
 
 import cv2
 import numpy as np
@@ -15,17 +14,15 @@ __all__ = ["crop_boxes", "create_shadow_mask"]
 
 def crop_boxes(
     boxes: np.ndarray,
-    crop_box: Union[Tuple[int, int, int, int], Tuple[float, float, float, float]],
+    crop_box: tuple[int, int, int, int] | tuple[float, float, float, float],
 ) -> np.ndarray:
     """Crop localization boxes
 
     Args:
-    ----
         boxes: ndarray of shape (N, 4) in relative or abs coordinates
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in the same coord format that the boxes
 
     Returns:
-    -------
         the cropped boxes
     """
     is_box_rel = boxes.max() <= 1
@@ -49,17 +46,15 @@ def crop_boxes(
     return boxes[is_valid]
 
 
-def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float, float]:
+def expand_line(line: np.ndarray, target_shape: tuple[int, int]) -> tuple[float, float]:
     """Expands a 2-point line, so that the first is on the edge. In other terms, we extend the line in
     the same direction until we meet one of the edges.
 
     Args:
-    ----
         line: array of shape (2, 2) of the point supposed to be on one edge, and the shadow tip.
         target_shape: the desired mask shape
 
     Returns:
-    -------
         2D coordinates of the first point once we extended the line (on one of the edges)
     """
     if any(coord == 0 or coord == size for coord, size in zip(line[0], target_shape[::-1])):
@@ -112,7 +107,7 @@ def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float,
 
 
 def create_shadow_mask(
-    target_shape: Tuple[int, int],
+    target_shape: tuple[int, int],
     min_base_width=0.3,
     max_tip_width=0.5,
     max_tip_height=0.3,
@@ -120,14 +115,12 @@ def create_shadow_mask(
     """Creates a random shadow mask
 
     Args:
-    ----
         target_shape: the target shape (H, W)
         min_base_width: the relative minimum shadow base width
         max_tip_width: the relative maximum shadow tip width
         max_tip_height: the relative maximum shadow tip height
 
     Returns:
-    -------
         a numpy ndarray of shape (H, W, 1) with values in the range [0, 1]
     """
     # Default base is top
doctr/transforms/functional/pytorch.py
CHANGED
@@ -1,13 +1,13 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Tuple
 
 import numpy as np
 import torch
+from scipy.ndimage import gaussian_filter
 from torchvision.transforms import functional as F
 
 from doctr.utils.geometry import rotate_abs_geoms
@@ -21,12 +21,10 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     """Invert the colors of an image
 
     Args:
-    ----
         img : torch.Tensor, the image to invert
         min_val : minimum value of the random shift
 
     Returns:
-    -------
         the inverted image
     """
     out = F.rgb_to_grayscale(img, num_output_channels=3)
@@ -35,9 +33,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
     # Inverse the color
     if out.dtype == torch.uint8:
-        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
+        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)  # type: ignore[attr-defined]
     else:
-        out = out * rgb_shift.to(dtype=out.dtype)
+        out = out * rgb_shift.to(dtype=out.dtype)  # type: ignore[attr-defined]
     # Inverse the color
     out = 255 - out if out.dtype == torch.uint8 else 1 - out
     return out
@@ -48,18 +46,16 @@ def rotate_sample(
     geoms: np.ndarray,
     angle: float,
     expand: bool = False,
-) -> Tuple[torch.Tensor, np.ndarray]:
+) -> tuple[torch.Tensor, np.ndarray]:
     """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
 
     Args:
-    ----
         img: image to rotate
         geoms: array of geometries of shape (N, 4) or (N, 4, 2)
         angle: angle in degrees. +: counter-clockwise, -: clockwise
         expand: whether the image should be padded before the rotation
 
     Returns:
-    -------
         A tuple of rotated img (tensor), rotated geometries of shape (N, 4, 2)
     """
     rotated_img = F.rotate(img, angle=angle, fill=0, expand=expand)  # Interpolation NEAREST by default
@@ -81,7 +77,7 @@ def rotate_sample(
     rotated_geoms: np.ndarray = rotate_abs_geoms(
         _geoms,
         angle,
-        img.shape[1:],
+        img.shape[1:],
         expand,
     ).astype(np.float32)
 
@@ -93,18 +89,16 @@ def rotate_sample(
 
 
 def crop_detection(
-    img: torch.Tensor, boxes: np.ndarray, crop_box: Tuple[float, float, float, float]
-) -> Tuple[torch.Tensor, np.ndarray]:
+    img: torch.Tensor, boxes: np.ndarray, crop_box: tuple[float, float, float, float]
+) -> tuple[torch.Tensor, np.ndarray]:
     """Crop and image and associated bboxes
 
     Args:
-    ----
         img: image to crop
         boxes: array of boxes to clip, absolute (int) or relative (float)
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords.
 
     Returns:
-    -------
         A tuple of cropped image, cropped boxes, where the image is not resized.
     """
     if any(val < 0 or val > 1 for val in crop_box):
@@ -119,27 +113,25 @@ def crop_detection(
     return cropped_img, boxes
 
 
-def random_shadow(img: torch.Tensor, opacity_range: Tuple[float, float], **kwargs) -> torch.Tensor:
-    """Crop and image and associated bboxes
+def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwargs) -> torch.Tensor:
+    """Apply a random shadow effect to an image using NumPy for blurring.
 
     Args:
-    ----
-        img: image to modify
-        opacity_range: the minimum and maximum desired opacity of the shadow
-        **kwargs: additional arguments to pass to `create_shadow_mask`
+        img: Image to modify (C, H, W) as a PyTorch tensor.
+        opacity_range: The minimum and maximum desired opacity of the shadow.
+        **kwargs: Additional arguments to pass to `create_shadow_mask`.
 
     Returns:
-    -------
-        shaded image
+        Shadowed image as a PyTorch tensor (same shape as input).
     """
-    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
-
+    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
     opacity = np.random.uniform(*opacity_range)
-    shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])
 
-    # Add some blur to make it believable
-    k = 7 + 2 * int(4 * np.random.rand(1))
+    # Apply Gaussian blur to the shadow mask
     sigma = np.random.uniform(0.5, 5.0)
-    shadow_tensor = F.gaussian_blur(shadow_tensor, k, sigma=[sigma, sigma])
+    blurred_mask = gaussian_filter(shadow_mask, sigma=sigma)
+
+    shadow_tensor = 1 - torch.from_numpy(blurred_mask).float()
+    shadow_tensor = shadow_tensor.to(img.device).unsqueeze(0)  # Add channel dimension
 
     return opacity * shadow_tensor * img + (1 - opacity) * img
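`random_shadow` is the one substantive rework here: the shadow mask is now blurred with `scipy.ndimage.gaussian_filter` while it is still a NumPy array (making scipy an explicit import of this module), and the channel dimension and target device are applied only afterwards. The signature is unchanged, so existing calls keep working; a small sketch, assuming `random_shadow` is re-exported through the star imports shown in the `__init__.py` diff below:

    import torch

    from doctr.transforms.functional import random_shadow

    img = torch.rand(3, 64, 64)  # (C, H, W) tensor in [0, 1]
    shadowed = random_shadow(img, opacity_range=(0.2, 0.8))
    print(shadowed.shape)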
doctr/transforms/functional/tensorflow.py
CHANGED
@@ -1,12 +1,12 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
 import random
+from collections.abc import Iterable
 from copy import deepcopy
-from typing import Iterable, Optional, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
@@ -22,12 +22,10 @@ def invert_colors(img: tf.Tensor, min_val: float = 0.6) -> tf.Tensor:
     """Invert the colors of an image
 
     Args:
-    ----
         img : tf.Tensor, the image to invert
         min_val : minimum value of the random shift
 
     Returns:
-    -------
         the inverted image
     """
     out = tf.image.rgb_to_grayscale(img)  # Convert to gray
@@ -48,13 +46,11 @@ def rotated_img_tensor(img: tf.Tensor, angle: float, expand: bool = False) -> tf.Tensor:
     """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
 
     Args:
-    ----
         img: image to rotate
         angle: angle in degrees. +: counter-clockwise, -: clockwise
         expand: whether the image should be padded before the rotation
 
     Returns:
-    -------
         the rotated image (tensor)
     """
     # Compute the expanded padding
@@ -103,18 +99,16 @@ def rotate_sample(
     geoms: np.ndarray,
     angle: float,
     expand: bool = False,
-) -> Tuple[tf.Tensor, np.ndarray]:
+) -> tuple[tf.Tensor, np.ndarray]:
     """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
 
     Args:
-    ----
         img: image to rotate
         geoms: array of geometries of shape (N, 4) or (N, 4, 2)
         angle: angle in degrees. +: counter-clockwise, -: clockwise
         expand: whether the image should be padded before the rotation
 
     Returns:
-    -------
         A tuple of rotated img (tensor), rotated boxes (np array)
     """
     # Rotated the image
@@ -140,22 +134,20 @@ def rotate_sample(
     rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[1]
     rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[0]
 
-    return rotated_img, np.clip(rotated_geoms, 0, 1)
+    return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)
 
 
 def crop_detection(
-    img: tf.Tensor, boxes: np.ndarray, crop_box: Tuple[float, float, float, float]
-) -> Tuple[tf.Tensor, np.ndarray]:
+    img: tf.Tensor, boxes: np.ndarray, crop_box: tuple[float, float, float, float]
+) -> tuple[tf.Tensor, np.ndarray]:
     """Crop and image and associated bboxes
 
     Args:
-    ----
         img: image to crop
         boxes: array of boxes to clip, absolute (int) or relative (float)
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords.
 
     Returns:
-    -------
         A tuple of cropped image, cropped boxes, where the image is not resized.
     """
     if any(val < 0 or val > 1 for val in crop_box):
@@ -172,16 +164,15 @@ def crop_detection(
 
 def _gaussian_filter(
     img: tf.Tensor,
-    kernel_size: Union[int, Iterable[int]],
+    kernel_size: int | Iterable[int],
     sigma: float,
-    mode: Optional[str] = None,
-    pad_value: Optional[int] = 0,
+    mode: str | None = None,
+    pad_value: int = 0,
 ):
     """Apply Gaussian filter to image.
     Adapted from: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/filters.py
 
     Args:
-    ----
         img: image to filter of shape (N, H, W, C)
         kernel_size: kernel size of the filter
         sigma: standard deviation of the Gaussian filter
@@ -189,7 +180,6 @@ def _gaussian_filter(
         pad_value: value to pad the image with
 
     Returns:
-    -------
         A tensor of shape (N, H, W, C)
     """
     ksize = tf.convert_to_tensor(tf.broadcast_to(kernel_size, [2]), dtype=tf.int32)
@@ -235,17 +225,15 @@ def _gaussian_filter(
     return tf.nn.depthwise_conv2d(img, g, [1, 1, 1, 1], padding="VALID", data_format="NHWC")
 
 
-def random_shadow(img: tf.Tensor, opacity_range: Tuple[float, float], **kwargs) -> tf.Tensor:
+def random_shadow(img: tf.Tensor, opacity_range: tuple[float, float], **kwargs) -> tf.Tensor:
     """Apply a random shadow to a given image
 
     Args:
-    ----
         img: image to modify
         opacity_range: the minimum and maximum desired opacity of the shadow
         **kwargs: additional arguments to pass to `create_shadow_mask`
 
     Returns:
-    -------
         shadowed image
     """
     shadow_mask = create_shadow_mask(img.shape[:2], **kwargs)
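One substantive change hides among the docstring cleanups: `rotate_sample` now rounds the relative geometry coordinates to 15 decimal places before clipping. Clipping alone only handles values outside [0, 1]; the rounding also snaps accumulated float noise inside the interval back to the intended value. A quick standalone illustration (not doctr code):

    import numpy as np

    x = 0.1 + 0.2
    print(x)  # 0.30000000000000004 -- accumulated float noise
    geoms = np.array([x, 1.0 + 2e-16])
    cleaned = np.clip(np.around(geoms, decimals=15), 0, 1)
    print(cleaned[0] == 0.3, cleaned[1] == 1.0)  # True True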
doctr/transforms/functional/__init__.py
CHANGED
@@ -2,7 +2,7 @@ from doctr.file_utils import is_tf_available, is_torch_available
 
 from .base import *
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
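
This flips the framework priority for the functional transforms: previously the TensorFlow implementations won when both backends were installed, now the PyTorch ones do, consistent with the PyTorch-only additions in the file list above (e.g. the `vip` and `viptr` modules). The guard functions are public, so you can check which backend a given environment resolves to:

    from doctr.file_utils import is_tf_available, is_torch_available

    # Mirrors the new import order: PyTorch wins when both are installed.
    if is_torch_available():
        print("using the PyTorch implementations")
    elif is_tf_available():
        print("using the TensorFlow implementations")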
|