python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/transforms/modules/base.py CHANGED

@@ -1,11 +1,12 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
 import random
-from typing import Any, Callable, List, Optional, Tuple, Union
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 
@@ -21,37 +22,36 @@ class SampleCompose(NestedObject):
 
     .. tabs::
 
-        .. tab:: TensorFlow
+        .. tab:: PyTorch
 
             .. code:: python
 
                 >>> import numpy as np
-                >>> import tensorflow as tf
+                >>> import torch
                 >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-                >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-                >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
+                >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
+                >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
 
-        .. tab:: PyTorch
+        .. tab:: TensorFlow
 
             .. code:: python
 
                 >>> import numpy as np
-                >>> import torch
+                >>> import tensorflow as tf
                 >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-                >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-                >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
+                >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
+                >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
 
     Args:
-    ----
         transforms: list of transformation modules
     """
 
-    _children_names: List[str] = ["sample_transforms"]
+    _children_names: list[str] = ["sample_transforms"]
 
-    def __init__(self, transforms: List[Callable[[Any, Any], Tuple[Any, Any]]]) -> None:
+    def __init__(self, transforms: list[Callable[[Any, Any], tuple[Any, Any]]]) -> None:
         self.sample_transforms = transforms
 
-    def __call__(self, x: Any, target: Any) -> Tuple[Any, Any]:
+    def __call__(self, x: Any, target: Any) -> tuple[Any, Any]:
         for t in self.sample_transforms:
             x, target = t(x, target)
 
@@ -63,35 +63,34 @@ class ImageTransform(NestedObject):
 
     .. tabs::
 
-        .. tab:: TensorFlow
+        .. tab:: PyTorch
 
             .. code:: python
 
-                >>> import tensorflow as tf
+                >>> import torch
                 >>> from doctr.transforms import ImageTransform, ColorInversion
                 >>> transfo = ImageTransform(ColorInversion((32, 32)))
-                >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
+                >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
 
-        .. tab:: PyTorch
+        .. tab:: TensorFlow
 
             .. code:: python
 
-                >>> import torch
+                >>> import tensorflow as tf
                 >>> from doctr.transforms import ImageTransform, ColorInversion
                 >>> transfo = ImageTransform(ColorInversion((32, 32)))
-                >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
+                >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
 
     Args:
-    ----
         transform: the image transformation module to wrap
     """
 
-    _children_names: List[str] = ["img_transform"]
+    _children_names: list[str] = ["img_transform"]
 
     def __init__(self, transform: Callable[[Any], Any]) -> None:
         self.img_transform = transform
 
-    def __call__(self, img: Any, target: Any) -> Tuple[Any, Any]:
+    def __call__(self, img: Any, target: Any) -> tuple[Any, Any]:
         img = self.img_transform(img)
         return img, target
 
@@ -102,26 +101,25 @@ class ColorInversion(NestedObject):
 
     .. tabs::
 
-        .. tab:: TensorFlow
+        .. tab:: PyTorch
 
             .. code:: python
 
-                >>> import tensorflow as tf
+                >>> import torch
                 >>> from doctr.transforms import ColorInversion
                 >>> transfo = ColorInversion(min_val=0.6)
-                >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
+                >>> out = transfo(torch.rand(8, 64, 64, 3))
 
-        .. tab:: PyTorch
+        .. tab:: TensorFlow
 
             .. code:: python
 
-                >>> import torch
+                >>> import tensorflow as tf
                 >>> from doctr.transforms import ColorInversion
                 >>> transfo = ColorInversion(min_val=0.6)
-                >>> out = transfo(torch.rand(8, 64, 64, 3))
+                >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
 
     Args:
-    ----
         min_val: range [min_val, 1] to colorize RGB pixels
     """
 
@@ -140,35 +138,34 @@ class OneOf(NestedObject):
 
     .. tabs::
 
-        .. tab:: TensorFlow
+        .. tab:: PyTorch
 
             .. code:: python
 
-                >>> import tensorflow as tf
+                >>> import torch
                 >>> from doctr.transforms import OneOf
                 >>> transfo = OneOf([JpegQuality(), Gamma()])
-                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+                >>> out = transfo(torch.rand(1, 64, 64, 3))
 
-        .. tab:: PyTorch
+        .. tab:: TensorFlow
 
             .. code:: python
 
-                >>> import torch
+                >>> import tensorflow as tf
                 >>> from doctr.transforms import OneOf
                 >>> transfo = OneOf([JpegQuality(), Gamma()])
-                >>> out = transfo(torch.rand(1, 64, 64, 3))
+                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
 
     Args:
-    ----
         transforms: list of transformations, one only will be picked
     """
 
-    _children_names: List[str] = ["transforms"]
+    _children_names: list[str] = ["transforms"]
 
-    def __init__(self, transforms: List[Callable[[Any], Any]]) -> None:
+    def __init__(self, transforms: list[Callable[[Any], Any]]) -> None:
         self.transforms = transforms
 
-    def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
+    def __call__(self, img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]:
         # Pick transformation
         transfo = self.transforms[int(random.random() * len(self.transforms))]
         # Apply
@@ -180,26 +177,25 @@ class RandomApply(NestedObject):
 
     .. tabs::
 
-        .. tab:: TensorFlow
+        .. tab:: PyTorch
 
             .. code:: python
 
-                >>> import tensorflow as tf
+                >>> import torch
                 >>> from doctr.transforms import RandomApply
                 >>> transfo = RandomApply(Gamma(), p=.5)
-                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+                >>> out = transfo(torch.rand(1, 64, 64, 3))
 
-        .. tab:: PyTorch
+        .. tab:: TensorFlow
 
             .. code:: python
 
-                >>> import torch
+                >>> import tensorflow as tf
                 >>> from doctr.transforms import RandomApply
                 >>> transfo = RandomApply(Gamma(), p=.5)
-                >>> out = transfo(torch.rand(1, 64, 64, 3))
+                >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
 
     Args:
-    ----
         transform: transformation to apply
         p: probability to apply
     """
@@ -211,7 +207,7 @@ class RandomApply(NestedObject):
     def extra_repr(self) -> str:
         return f"transform={self.transform}, p={self.p}"
 
-    def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
+    def __call__(self, img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]:
         if random.random() < self.p:
             return self.transform(img) if target is None else self.transform(img, target)  # type: ignore[call-arg]
         return img if target is None else (img, target)
@@ -224,9 +220,7 @@ class RandomRotate(NestedObject):
         :align: center
 
     Args:
-    ----
-        max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in
-        [-max_angle, max_angle]
+        max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in [-max_angle, max_angle]
         expand: whether the image should be padded before the rotation
     """
 
@@ -237,7 +231,7 @@ class RandomRotate(NestedObject):
     def extra_repr(self) -> str:
         return f"max_angle={self.max_angle}, expand={self.expand}"
 
-    def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
+    def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]:
         angle = random.uniform(-self.max_angle, self.max_angle)
         r_img, r_polys = F.rotate_sample(img, target, angle, self.expand)
         # Removes deleted boxes
@@ -249,19 +243,18 @@ class RandomCrop(NestedObject):
     """Randomly crop a tensor image and its boxes
 
     Args:
-    ----
         scale: tuple of floats, relative (min_area, max_area) of the crop
         ratio: tuple of float, relative (min_ratio, max_ratio) where ratio = h/w
     """
 
-    def __init__(self, scale: Tuple[float, float] = (0.08, 1.0), ratio: Tuple[float, float] = (0.75, 1.33)) -> None:
+    def __init__(self, scale: tuple[float, float] = (0.08, 1.0), ratio: tuple[float, float] = (0.75, 1.33)) -> None:
         self.scale = scale
         self.ratio = ratio
 
     def extra_repr(self) -> str:
         return f"scale={self.scale}, ratio={self.ratio}"
 
-    def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
+    def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]:
         scale = random.uniform(self.scale[0], self.scale[1])
         ratio = random.uniform(self.ratio[0], self.ratio[1])
 
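Most of the churn in this file (and across the wheel) is the same typing modernization: List, Tuple, Optional and Union from typing give way to builtin generics (PEP 585) and PEP 604 unions, and Callable now comes from collections.abc. A minimal sketch of the pattern; the apply helper below is hypothetical, not doctr code:

from collections.abc import Callable
from typing import Any


# 0.9.0 style (needed: from typing import List, Optional, Tuple):
#   def apply(fns: List[Callable[[Any], Any]], x: Any, n: Optional[int] = None) -> Tuple[Any, ...]: ...
# 0.11.0 style, with builtin generics and "| None" unions:
def apply(fns: list[Callable[[Any], Any]], x: Any, n: int | None = None) -> tuple[Any, ...]:
    """Run each callable on x, keeping only the first n results if n is given."""
    results = tuple(fn(x) for fn in fns)
    return results if n is None else results[:n]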
doctr/transforms/modules/pytorch.py CHANGED

@@ -1,21 +1,29 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
-from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 from PIL.Image import Image
+from scipy.ndimage import gaussian_filter
 from torch.nn.functional import pad
 from torchvision.transforms import functional as F
 from torchvision.transforms import transforms as T
 
 from ..functional.pytorch import random_shadow
 
-__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"]
+__all__ = [
+    "Resize",
+    "GaussianNoise",
+    "ChannelShuffle",
+    "RandomHorizontalFlip",
+    "RandomShadow",
+    "RandomResize",
+    "GaussianBlur",
+]
 
 
 class Resize(T.Resize):
@@ -23,7 +31,7 @@ class Resize(T.Resize):
 
     def __init__(
         self,
-        size: Union[int, Tuple[int, int]],
+        size: int | tuple[int, int],
         interpolation=F.InterpolationMode.BILINEAR,
         preserve_aspect_ratio: bool = False,
         symmetric_pad: bool = False,
@@ -38,8 +46,8 @@ class Resize(T.Resize):
     def forward(
         self,
         img: torch.Tensor,
-        target: Optional[np.ndarray] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, np.ndarray]]:
+        target: np.ndarray | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
         if isinstance(self.size, int):
             target_ratio = img.shape[-2] / img.shape[-1]
         else:
@@ -74,16 +82,18 @@ class Resize(T.Resize):
             if self.symmetric_pad:
                 half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2))
                 _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1])
+            # Pad image
             img = pad(img, _pad)
 
         # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
         if target is not None:
+            if self.symmetric_pad:
+                offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
+
             if self.preserve_aspect_ratio:
                 # Get absolute coords
                 if target.shape[1:] == (4,):
                     if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
                         target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
                         target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
                     else:
@@ -91,16 +101,15 @@ class Resize(T.Resize):
                         target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
                 elif target.shape[1:] == (4, 2):
                     if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
                         target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
                         target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
                     else:
                         target[..., 0] *= raw_shape[-1] / img.shape[-1]
                         target[..., 1] *= raw_shape[-2] / img.shape[-2]
                 else:
-                    raise AssertionError
-            return img, target
+                    raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")
+
+            return img, np.clip(target, 0, 1)
 
         return img
 
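The two Resize.forward hunks above also fix a subtle bug: the padding offset used to be computed only inside the "if np.max(target) <= 1" guard, so boxes in absolute coordinates could reach the shift step with no offset defined; the offset is now computed whenever symmetric_pad is set, and the returned targets are clipped to [0, 1]. A usage sketch of the updated transform, with illustrative shapes and box values:

import numpy as np
import torch

from doctr.transforms import Resize  # the PyTorch module shown in this diff

# Keep the aspect ratio and pad both sides: relative boxes are shifted by the
# padding offset, and the returned targets are clipped to [0, 1].
transfo = Resize((32, 32), preserve_aspect_ratio=True, symmetric_pad=True)
img, boxes = transfo(torch.rand(3, 64, 128), np.array([[0.1, 0.1, 0.9, 0.9]]))
print(img.shape)  # torch.Size([3, 32, 32])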
@@ -121,7 +130,6 @@ class GaussianNoise(torch.nn.Module):
     >>> out = transfo(torch.rand((3, 224, 224)))
 
     Args:
-    ----
         mean : mean of the gaussian distribution
         std : std of the gaussian distribution
     """
@@ -135,14 +143,47 @@ class GaussianNoise(torch.nn.Module):
         # Reshape the distribution
         noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std
         if x.dtype == torch.uint8:
-            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
+            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)  # type: ignore[attr-defined]
         else:
-            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)
+            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)  # type: ignore[attr-defined]
 
     def extra_repr(self) -> str:
         return f"mean={self.mean}, std={self.std}"
 
 
+class GaussianBlur(torch.nn.Module):
+    """Apply Gaussian Blur to the input tensor
+
+    >>> import torch
+    >>> from doctr.transforms import GaussianBlur
+    >>> transfo = GaussianBlur(sigma=(0.0, 1.0))
+
+    Args:
+        sigma : standard deviation range for the gaussian kernel
+    """
+
+    def __init__(self, sigma: tuple[float, float]) -> None:
+        super().__init__()
+        self.sigma_range = sigma
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Sample a random sigma value within the specified range
+        sigma = torch.empty(1).uniform_(*self.sigma_range).item()
+
+        # Apply Gaussian blur along spatial dimensions only
+        blurred = torch.tensor(
+            gaussian_filter(
+                x.numpy(),
+                sigma=sigma,
+                mode="reflect",
+                truncate=4.0,
+            ),
+            dtype=x.dtype,
+            device=x.device,
+        )
+        return blurred
+
+
 class ChannelShuffle(torch.nn.Module):
     """Randomly shuffle channel order of a given image"""
 
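GaussianBlur is the one genuinely new transform in this file: it draws a sigma from the configured range and delegates to scipy.ndimage.gaussian_filter, round-tripping the tensor through NumPy. A usage sketch consistent with the docstring added above:

import torch

from doctr.transforms import GaussianBlur

transfo = GaussianBlur(sigma=(0.5, 1.5))
# The blur goes through x.numpy(), so the input tensor should live on the CPU.
out = transfo(torch.rand(3, 64, 64))
print(out.shape)  # torch.Size([3, 64, 64])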
@@ -158,9 +199,7 @@ class ChannelShuffle(torch.nn.Module):
 class RandomHorizontalFlip(T.RandomHorizontalFlip):
     """Randomly flip the input image horizontally"""
 
-    def forward(
-        self, img: Union[torch.Tensor, Image], target: np.ndarray
-    ) -> Tuple[Union[torch.Tensor, Image], np.ndarray]:
+    def forward(self, img: torch.Tensor | Image, target: np.ndarray) -> tuple[torch.Tensor | Image, np.ndarray]:
         if torch.rand(1) < self.p:
             _img = F.hflip(img)
             _target = target.copy()
@@ -182,11 +221,10 @@ class RandomShadow(torch.nn.Module):
     >>> out = transfo(torch.rand((3, 64, 64)))
 
     Args:
-    ----
         opacity_range : minimum and maximum opacity of the shade
     """
 
-    def __init__(self, opacity_range: Optional[Tuple[float, float]] = None) -> None:
+    def __init__(self, opacity_range: tuple[float, float] | None = None) -> None:
         super().__init__()
         self.opacity_range = opacity_range if isinstance(opacity_range, tuple) else (0.2, 0.8)
 
@@ -195,7 +233,7 @@ class RandomShadow(torch.nn.Module):
         try:
             if x.dtype == torch.uint8:
                 return (
-                    (
+                    (  # type: ignore[attr-defined]
                         255
                         * random_shadow(
                             x.to(dtype=torch.float32) / 255,
@@ -224,20 +262,19 @@ class RandomResize(torch.nn.Module):
     >>> out = transfo(torch.rand((3, 64, 64)))
 
     Args:
-    ----
         scale_range: range of the resizing factor for width and height (independently)
         preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
-        given a float value, the aspect ratio will be preserved with this probability
+            given a float value, the aspect ratio will be preserved with this probability
         symmetric_pad: whether to symmetrically pad the image,
-        given a float value, the symmetric padding will be applied with this probability
+            given a float value, the symmetric padding will be applied with this probability
         p: probability to apply the transformation
     """
 
     def __init__(
         self,
-        scale_range: Tuple[float, float] = (0.3, 0.9),
-        preserve_aspect_ratio: Union[bool, float] = False,
-        symmetric_pad: Union[bool, float] = False,
+        scale_range: tuple[float, float] = (0.3, 0.9),
+        preserve_aspect_ratio: bool | float = False,
+        symmetric_pad: bool | float = False,
         p: float = 0.5,
     ) -> None:
         super().__init__()
@@ -247,7 +284,7 @@ class RandomResize(torch.nn.Module):
         self.p = p
         self._resize = Resize
 
-    def forward(self, img: torch.Tensor, target: np.ndarray) -> Tuple[torch.Tensor, np.ndarray]:
+    def forward(self, img: torch.Tensor, target: np.ndarray) -> tuple[torch.Tensor, np.ndarray]:
         if torch.rand(1) < self.p:
             scale_h = np.random.uniform(*self.scale_range)
             scale_w = np.random.uniform(*self.scale_range)