python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/utils/tensorflow.py
CHANGED
@@ -1,12 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-import
-from typing import Any
-from zipfile import ZipFile
+from collections.abc import Callable
+from typing import Any
 
 import tensorflow as tf
 import tf2onnx
@@ -19,6 +18,7 @@ logging.getLogger("tensorflow").setLevel(logging.DEBUG)
 
 __all__ = [
     "load_pretrained_params",
+    "_build_model",
     "conv_sequence",
     "IntermediateLayerGetter",
     "export_model_to_onnx",
@@ -36,51 +36,50 @@ def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
     return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x
 
 
+def _build_model(model: Model):
+    """Build a model by calling it once with dummy input
+
+    Args:
+        model: the model to be built
+    """
+    model(tf.zeros((1, *model.cfg["input_shape"])), training=False)
+
+
 def load_pretrained_params(
     model: Model,
-    url:
-    hash_prefix:
-
-    internal_name: str = "weights",
+    url: str | None = None,
+    hash_prefix: str | None = None,
+    skip_mismatch: bool = False,
     **kwargs: Any,
 ) -> None:
     """Load a set of parameters onto a model
 
     >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.
+    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5")
 
     Args:
-    ----
         model: the keras model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
-
-        internal_name: name of the ckpt files
+        skip_mismatch: skip loading layers with mismatched shapes
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
     if url is None:
         logging.warning("Invalid model URL, using default initialization.")
     else:
         archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
-
-        # Unzip the archive
-        params_path = archive_path.parent.joinpath(archive_path.stem)
-        if not params_path.is_dir() or overwrite:
-            with ZipFile(archive_path, "r") as f:
-                f.extractall(path=params_path)
-
         # Load weights
-        model.load_weights(
+        model.load_weights(archive_path, skip_mismatch=skip_mismatch)
 
 
 def conv_sequence(
     out_channels: int,
-    activation:
+    activation: str | Callable | None = None,
     bn: bool = False,
     padding: str = "same",
     kernel_initializer: str = "he_normal",
     **kwargs: Any,
-) ->
+) -> list[layers.Layer]:
     """Builds a convolutional-based layer sequence
 
     >>> from tensorflow.keras import Sequential
@@ -88,7 +87,6 @@ def conv_sequence(
     >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))
 
     Args:
-    ----
         out_channels: number of output channels
         activation: activation to be used (default: no activation)
         bn: should a batch normalization layer be added
@@ -97,7 +95,6 @@ def conv_sequence(
         **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
-    -------
         list of layers
     """
     # No bias before Batch norm
@@ -125,12 +122,11 @@ class IntermediateLayerGetter(Model):
    >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers)
 
    Args:
-    ----
        model: the model to extract feature maps from
        layer_names: the list of layers to retrieve the feature map from
    """
 
-    def __init__(self, model: Model, layer_names:
+    def __init__(self, model: Model, layer_names: list[str]) -> None:
        intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names]
        super().__init__(model.input, outputs=intermediate_fmaps)
 
@@ -139,8 +135,8 @@ class IntermediateLayerGetter(Model):
 
 
 def export_model_to_onnx(
-    model: Model, model_name: str, dummy_input:
-) ->
+    model: Model, model_name: str, dummy_input: list[tf.TensorSpec], **kwargs: Any
+) -> tuple[str, list[str]]:
     """Export model to ONNX format.
 
     >>> import tensorflow as tf
@@ -151,16 +147,18 @@ def export_model_to_onnx(
     >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")])
 
     Args:
-    ----
         model: the keras model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to tf2onnx
 
     Returns:
-    -------
         the path to the exported model and a list with the output layer names
     """
+    # get the users eager mode
+    eager_mode = tf.executing_eagerly()
+    # set eager mode to true to avoid issues with tf2onnx
+    tf.config.run_functions_eagerly(True)
     large_model = kwargs.get("large_model", False)
     model_proto, _ = tf2onnx.convert.from_keras(
         model,
@@ -171,6 +169,9 @@ def export_model_to_onnx(
     # Get the output layer names
     output = [n.name for n in model_proto.graph.output]
 
+    # reset the eager mode to the users mode
+    tf.config.run_functions_eagerly(eager_mode)
+
     # models which are too large (weights > 2GB while converting to ONNX) needs to be handled
     # about an external tensor storage where the graph and weights are seperatly stored in a archive
     if large_model:
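The change above replaces the old zip-archive checkpoint flow with direct `.weights.h5` loading, drops `internal_name` in favor of `skip_mismatch`, and introduces `_build_model` so that the model's variables exist before `load_weights` runs. A minimal sketch of the new flow; the URL and hash are placeholders, and `resnet18` stands in for any doctr TensorFlow model whose `cfg` carries `input_shape`:

    from doctr.models import load_pretrained_params
    from doctr.models.classification import resnet18
    from doctr.models.utils.tensorflow import _build_model

    model = resnet18(pretrained=False)
    _build_model(model)  # one dummy forward pass so the weights are created
    load_pretrained_params(
        model,
        url="https://example.com/checkpoint-abcd1234.weights.h5",  # placeholder URL
        hash_prefix="abcd1234",  # placeholder hash prefix
        skip_mismatch=True,  # tolerate layers whose shapes changed
    )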
doctr/models/zoo.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -83,7 +83,6 @@ def ocr_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -108,7 +107,6 @@ def ocr_predictor(
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
-    -------
         OCR predictor
     """
     return _predictor(
@@ -197,7 +195,6 @@ def kie_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -222,7 +219,6 @@ def kie_predictor(
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
-    -------
         KIE predictor
     """
     return _kie_predictor(
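Only docstring separators and the copyright year change here; the entry points themselves are untouched. For context, their documented usage (taken straight from the docstring examples, with a random page as stand-in input):

    import numpy as np

    from doctr.models import ocr_predictor

    model = ocr_predictor("db_resnet50", "crnn_vgg16_bn", pretrained=True)
    input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
    out = model([input_page])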
doctr/transforms/functional/base.py
CHANGED
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Tuple, Union
 
 import cv2
 import numpy as np
@@ -15,17 +14,15 @@ __all__ = ["crop_boxes", "create_shadow_mask"]
 
 def crop_boxes(
     boxes: np.ndarray,
-    crop_box:
+    crop_box: tuple[int, int, int, int] | tuple[float, float, float, float],
 ) -> np.ndarray:
     """Crop localization boxes
 
     Args:
-    ----
         boxes: ndarray of shape (N, 4) in relative or abs coordinates
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in the same coord format that the boxes
 
     Returns:
-    -------
         the cropped boxes
     """
     is_box_rel = boxes.max() <= 1
@@ -49,17 +46,15 @@ def crop_boxes(
     return boxes[is_valid]
 
 
-def expand_line(line: np.ndarray, target_shape:
+def expand_line(line: np.ndarray, target_shape: tuple[int, int]) -> tuple[float, float]:
     """Expands a 2-point line, so that the first is on the edge. In other terms, we extend the line in
     the same direction until we meet one of the edges.
 
     Args:
-    ----
         line: array of shape (2, 2) of the point supposed to be on one edge, and the shadow tip.
         target_shape: the desired mask shape
 
     Returns:
-    -------
         2D coordinates of the first point once we extended the line (on one of the edges)
     """
     if any(coord == 0 or coord == size for coord, size in zip(line[0], target_shape[::-1])):
@@ -112,7 +107,7 @@ def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float,
 
 
 def create_shadow_mask(
-    target_shape:
+    target_shape: tuple[int, int],
     min_base_width=0.3,
     max_tip_width=0.5,
     max_tip_height=0.3,
@@ -120,14 +115,12 @@ def create_shadow_mask(
     """Creates a random shadow mask
 
     Args:
-    ----
         target_shape: the target shape (H, W)
         min_base_width: the relative minimum shadow base width
         max_tip_width: the relative maximum shadow tip width
         max_tip_height: the relative maximum shadow tip height
 
     Returns:
-    -------
         a numpy ndarray of shape (H, W, 1) with values in the range [0, 1]
     """
     # Default base is top
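For readers following the `crop_boxes` signature change, a small worked sketch (the numbers are made up): boxes and the crop box must share one coordinate format; surviving boxes are remapped into the crop frame and zero-area ones are dropped.

    import numpy as np

    from doctr.transforms.functional import crop_boxes

    boxes = np.array([[0.1, 0.1, 0.4, 0.3], [0.6, 0.6, 0.9, 0.9]], dtype=np.float32)
    # Keep the top-left quadrant: the second box collapses to zero area and is
    # dropped, the first is rescaled into the crop frame -> [[0.2, 0.2, 0.8, 0.6]]
    cropped = crop_boxes(boxes, (0.0, 0.0, 0.5, 0.5))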
doctr/transforms/functional/pytorch.py
CHANGED
@@ -1,13 +1,13 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Tuple
 
 import numpy as np
 import torch
+from scipy.ndimage import gaussian_filter
 from torchvision.transforms import functional as F
 
 from doctr.utils.geometry import rotate_abs_geoms
@@ -21,12 +21,10 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     """Invert the colors of an image
 
     Args:
-    ----
         img : torch.Tensor, the image to invert
         min_val : minimum value of the random shift
 
     Returns:
-    -------
         the inverted image
     """
     out = F.rgb_to_grayscale(img, num_output_channels=3)
@@ -35,9 +33,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
     # Inverse the color
     if out.dtype == torch.uint8:
-        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
+        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)  # type: ignore[attr-defined]
     else:
-        out = out * rgb_shift.to(dtype=out.dtype)
+        out = out * rgb_shift.to(dtype=out.dtype)  # type: ignore[attr-defined]
     # Inverse the color
     out = 255 - out if out.dtype == torch.uint8 else 1 - out
     return out
@@ -48,18 +46,16 @@ def rotate_sample(
     geoms: np.ndarray,
     angle: float,
     expand: bool = False,
-) ->
+) -> tuple[torch.Tensor, np.ndarray]:
     """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
 
     Args:
-    ----
         img: image to rotate
         geoms: array of geometries of shape (N, 4) or (N, 4, 2)
         angle: angle in degrees. +: counter-clockwise, -: clockwise
         expand: whether the image should be padded before the rotation
 
     Returns:
-    -------
         A tuple of rotated img (tensor), rotated geometries of shape (N, 4, 2)
     """
     rotated_img = F.rotate(img, angle=angle, fill=0, expand=expand)  # Interpolation NEAREST by default
@@ -81,7 +77,7 @@ def rotate_sample(
     rotated_geoms: np.ndarray = rotate_abs_geoms(
         _geoms,
         angle,
-        img.shape[1:],
+        img.shape[1:],
         expand,
     ).astype(np.float32)
 
@@ -89,22 +85,20 @@ def rotate_sample(
     rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[2]
     rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[1]
 
-    return rotated_img, np.clip(rotated_geoms, 0, 1)
+    return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)
 
 
 def crop_detection(
-    img: torch.Tensor, boxes: np.ndarray, crop_box:
-) ->
+    img: torch.Tensor, boxes: np.ndarray, crop_box: tuple[float, float, float, float]
+) -> tuple[torch.Tensor, np.ndarray]:
     """Crop and image and associated bboxes
 
     Args:
-    ----
         img: image to crop
         boxes: array of boxes to clip, absolute (int) or relative (float)
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords.
 
     Returns:
-    -------
         A tuple of cropped image, cropped boxes, where the image is not resized.
     """
     if any(val < 0 or val > 1 for val in crop_box):
@@ -119,27 +113,25 @@ def crop_detection(
     return cropped_img, boxes
 
 
-def random_shadow(img: torch.Tensor, opacity_range:
-    """
+def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwargs) -> torch.Tensor:
+    """Apply a random shadow effect to an image using NumPy for blurring.
 
     Args:
-
-
-
-        **kwargs: additional arguments to pass to `create_shadow_mask`
+        img: Image to modify (C, H, W) as a PyTorch tensor.
+        opacity_range: The minimum and maximum desired opacity of the shadow.
+        **kwargs: Additional arguments to pass to `create_shadow_mask`.
 
     Returns:
-
-        shaded image
+        Shadowed image as a PyTorch tensor (same shape as input).
     """
-    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
-
+    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
     opacity = np.random.uniform(*opacity_range)
-    shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])
 
-    #
-    k = 7 + 2 * int(4 * np.random.rand(1))
+    # Apply Gaussian blur to the shadow mask
     sigma = np.random.uniform(0.5, 5.0)
-
+    blurred_mask = gaussian_filter(shadow_mask, sigma=sigma)
+
+    shadow_tensor = 1 - torch.from_numpy(blurred_mask).float()
+    shadow_tensor = shadow_tensor.to(img.device).unsqueeze(0)  # Add channel dimension
 
     return opacity * shadow_tensor * img + (1 - opacity) * img
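The PyTorch `random_shadow` now blurs the shadow mask with `scipy.ndimage.gaussian_filter` instead of the old odd-kernel approach, which is why `scipy` enters the imports. A shape-level smoke test, assuming the function is re-exported from `doctr.transforms.functional`:

    import torch

    from doctr.transforms.functional import random_shadow

    img = torch.rand(3, 64, 64)  # (C, H, W) float image in [0, 1]
    shadowed = random_shadow(img, opacity_range=(0.2, 0.8))
    assert shadowed.shape == img.shape  # the effect itself is random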
doctr/transforms/functional/tensorflow.py
CHANGED
@@ -1,12 +1,12 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
 import random
+from collections.abc import Iterable
 from copy import deepcopy
-from typing import Iterable, Optional, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
@@ -22,12 +22,10 @@ def invert_colors(img: tf.Tensor, min_val: float = 0.6) -> tf.Tensor:
     """Invert the colors of an image
 
     Args:
-    ----
         img : tf.Tensor, the image to invert
         min_val : minimum value of the random shift
 
     Returns:
-    -------
         the inverted image
     """
     out = tf.image.rgb_to_grayscale(img)  # Convert to gray
@@ -48,13 +46,11 @@ def rotated_img_tensor(img: tf.Tensor, angle: float, expand: bool = False) -> tf
     """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
 
     Args:
-    ----
         img: image to rotate
         angle: angle in degrees. +: counter-clockwise, -: clockwise
         expand: whether the image should be padded before the rotation
 
     Returns:
-    -------
         the rotated image (tensor)
     """
     # Compute the expanded padding
@@ -103,18 +99,16 @@ def rotate_sample(
     geoms: np.ndarray,
     angle: float,
     expand: bool = False,
-) ->
+) -> tuple[tf.Tensor, np.ndarray]:
     """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
 
     Args:
-    ----
         img: image to rotate
         geoms: array of geometries of shape (N, 4) or (N, 4, 2)
         angle: angle in degrees. +: counter-clockwise, -: clockwise
         expand: whether the image should be padded before the rotation
 
     Returns:
-    -------
         A tuple of rotated img (tensor), rotated boxes (np array)
     """
     # Rotated the image
@@ -140,22 +134,20 @@ def rotate_sample(
     rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[1]
     rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[0]
 
-    return rotated_img, np.clip(rotated_geoms, 0, 1)
+    return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)
 
 
 def crop_detection(
-    img: tf.Tensor, boxes: np.ndarray, crop_box:
-) ->
+    img: tf.Tensor, boxes: np.ndarray, crop_box: tuple[float, float, float, float]
+) -> tuple[tf.Tensor, np.ndarray]:
     """Crop and image and associated bboxes
 
     Args:
-    ----
         img: image to crop
         boxes: array of boxes to clip, absolute (int) or relative (float)
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords.
 
     Returns:
-    -------
         A tuple of cropped image, cropped boxes, where the image is not resized.
     """
     if any(val < 0 or val > 1 for val in crop_box):
@@ -172,16 +164,15 @@ def crop_detection(
 
 def _gaussian_filter(
     img: tf.Tensor,
-    kernel_size:
+    kernel_size: int | Iterable[int],
     sigma: float,
-    mode:
-    pad_value:
+    mode: str | None = None,
+    pad_value: int = 0,
 ):
     """Apply Gaussian filter to image.
     Adapted from: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/filters.py
 
     Args:
-    ----
         img: image to filter of shape (N, H, W, C)
         kernel_size: kernel size of the filter
         sigma: standard deviation of the Gaussian filter
@@ -189,7 +180,6 @@ def _gaussian_filter(
         pad_value: value to pad the image with
 
     Returns:
-    -------
         A tensor of shape (N, H, W, C)
     """
     ksize = tf.convert_to_tensor(tf.broadcast_to(kernel_size, [2]), dtype=tf.int32)
@@ -235,17 +225,15 @@ def _gaussian_filter(
     return tf.nn.depthwise_conv2d(img, g, [1, 1, 1, 1], padding="VALID", data_format="NHWC")
 
 
-def random_shadow(img: tf.Tensor, opacity_range:
+def random_shadow(img: tf.Tensor, opacity_range: tuple[float, float], **kwargs) -> tf.Tensor:
     """Apply a random shadow to a given image
 
     Args:
-    ----
         img: image to modify
         opacity_range: the minimum and maximum desired opacity of the shadow
         **kwargs: additional arguments to pass to `create_shadow_mask`
 
     Returns:
-    -------
         shadowed image
     """
     shadow_mask = create_shadow_mask(img.shape[:2], **kwargs)
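The TensorFlow counterpart keeps its in-repo `_gaussian_filter` and mostly modernizes annotations; images are channels-last here. A parallel smoke-test sketch against the backend module directly:

    import tensorflow as tf

    from doctr.transforms.functional.tensorflow import random_shadow

    img = tf.random.uniform((64, 64, 3), maxval=1.0)  # (H, W, C) float image
    shadowed = random_shadow(img, opacity_range=(0.2, 0.8))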
doctr/transforms/modules/__init__.py
CHANGED
@@ -2,7 +2,7 @@ from doctr.file_utils import is_tf_available, is_torch_available
 
 from .base import *
 
-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]