python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the versions exactly as they appear in that public registry.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +1 -5
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1100 -54
- doctr/file_utils.py +2 -92
- doctr/io/elements.py +37 -3
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +4 -4
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -4
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +11 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vip/__init__.py +1 -0
- doctr/models/classification/vip/layers/__init__.py +1 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +12 -3
- doctr/models/classification/zoo.py +7 -8
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +7 -16
- doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +6 -17
- doctr/models/detection/fast/pytorch.py +17 -8
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +5 -15
- doctr/models/detection/linknet/pytorch.py +12 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +9 -22
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +52 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +16 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +15 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +26 -8
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +100 -47
- doctr/models/recognition/predictor/pytorch.py +4 -5
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +13 -4
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +1 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +13 -4
- doctr/models/recognition/zoo.py +13 -8
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +29 -19
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -433
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -397
- doctr/models/classification/textnet/tensorflow.py +0 -266
- doctr/models/classification/vgg/tensorflow.py +0 -116
- doctr/models/classification/vit/tensorflow.py +0 -192
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
- doctr/models/detection/fast/tensorflow.py +0 -419
- doctr/models/detection/linknet/tensorflow.py +0 -369
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -308
- doctr/models/recognition/master/tensorflow.py +0 -313
- doctr/models/recognition/parseq/tensorflow.py +0 -508
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -416
- doctr/models/recognition/vitstr/tensorflow.py +0 -278
- doctr/models/utils/tensorflow.py +0 -182
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.11.0.dist-info/RECORD +0 -173
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
doctr/models/utils/pytorch.py
CHANGED
@@ -7,6 +7,7 @@ import logging
 from typing import Any

 import torch
+import validators
 from torch import nn

 from doctr.utils.data import download_from_url
@@ -36,7 +37,7 @@ def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:

 def load_pretrained_params(
     model: nn.Module,
+    path_or_url: str | None = None,
     hash_prefix: str | None = None,
     ignore_keys: list[str] | None = None,
     **kwargs: Any,
@@ -44,33 +45,42 @@ def load_pretrained_params(
     """Load a set of parameters onto a model

     >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.
+    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")

     Args:
         model: the PyTorch model to be loaded
+        path_or_url: the path or URL to the model parameters (checkpoint)
         hash_prefix: first characters of SHA256 expected hash
         ignore_keys: list of weights to be ignored from the state_dict
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
-    if
-        logging.warning("
+    if path_or_url is None:
+        logging.warning("No model URL or Path provided, using default initialization.")
+        return
+
+    archive_path = (
+        download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
+        if validators.url(path_or_url)
+        else path_or_url
+    )

+    # Read state_dict
+    state_dict = torch.load(archive_path, map_location="cpu")

+    # Remove weights from the state_dict
+    if ignore_keys is not None and len(ignore_keys) > 0:
+        for key in ignore_keys:
+            if key in state_dict:
                 state_dict.pop(key)
+    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+    if any(k not in ignore_keys for k in missing_keys + unexpected_keys):
+        raise ValueError(
+            "Unable to load state_dict, due to non-matching keys.\n"
+            + f"Unexpected keys: {unexpected_keys}\nMissing keys: {missing_keys}"
+        )
+    else:
+        # Load weights
+        model.load_state_dict(state_dict)


 def conv_sequence_pt(
@@ -154,7 +164,7 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     """
     torch.onnx.export(
         model,
-        dummy_input,
+        dummy_input,  # type: ignore[arg-type]
         f"{model_name}.onnx",
         input_names=["input"],
         output_names=["logits"],
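The reworked `load_pretrained_params` above now takes a single `path_or_url` argument and uses `validators.url` to decide between downloading a checkpoint and loading a local file. A minimal usage sketch based on that hunk — the toy model and the commented-out path/URL are placeholders, not part of the diff:

    import torch
    from doctr.models import load_pretrained_params

    model = torch.nn.Linear(4, 4)  # stand-in for any doctr model

    # No checkpoint given: a warning is logged and the default initialization is kept.
    load_pretrained_params(model)

    # Local file: loaded directly, since validators.url() rejects a plain filesystem path.
    # load_pretrained_params(model, "/tmp/my_checkpoint.pt")

    # Remote file: downloaded into the "models" cache subdir, then read with torch.load.
    # load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt", hash_prefix="yourhash")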
doctr/transforms/functional/pytorch.py
CHANGED
@@ -33,9 +33,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
     # Inverse the color
     if out.dtype == torch.uint8:
-        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
+        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
     else:
-        out = out * rgb_shift.to(dtype=out.dtype)
+        out = out * rgb_shift.to(dtype=out.dtype)
     # Inverse the color
     out = 255 - out if out.dtype == torch.uint8 else 1 - out
     return out
@@ -77,7 +77,7 @@ def rotate_sample(
     rotated_geoms: np.ndarray = rotate_abs_geoms(
         _geoms,
         angle,
-        img.shape[1:],
+        img.shape[1:],  # type: ignore[arg-type]
         expand,
     ).astype(np.float32)

@@ -124,7 +124,7 @@ def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwarg
     Returns:
         Shadowed image as a PyTorch tensor (same shape as input).
     """
-    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
+    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)  # type: ignore[arg-type]
     opacity = np.random.uniform(*opacity_range)

     # Apply Gaussian blur to the shadow mask
doctr/transforms/modules/base.py
CHANGED
@@ -20,27 +20,13 @@ __all__ = ["SampleCompose", "ImageTransform", "ColorInversion", "OneOf", "Random
 class SampleCompose(NestedObject):
     """Implements a wrapper that will apply transformations sequentially on both image and target

-    ..
+    .. code:: python

-    >>> import torch
-    >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-    >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-    >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
-
-    .. tab:: TensorFlow
-
-    .. code:: python
-
-    >>> import numpy as np
-    >>> import tensorflow as tf
-    >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-    >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-    >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
+    >>> import numpy as np
+    >>> import torch
+    >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
+    >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
+    >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))

     Args:
         transforms: list of transformation modules
@@ -61,25 +47,12 @@ class SampleCompose(NestedObject):
 class ImageTransform(NestedObject):
     """Implements a transform wrapper to turn an image-only transformation into an image+target transform

-    ..
-    .. tab:: PyTorch
-    .. code:: python
+    .. code:: python

-    .. tab:: TensorFlow
-    .. code:: python
-    >>> import tensorflow as tf
-    >>> from doctr.transforms import ImageTransform, ColorInversion
-    >>> transfo = ImageTransform(ColorInversion((32, 32)))
-    >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
+    >>> import torch
+    >>> from doctr.transforms import ImageTransform, ColorInversion
+    >>> transfo = ImageTransform(ColorInversion((32, 32)))
+    >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)

     Args:
         transform: the image transformation module to wrap
@@ -99,25 +72,12 @@ class ColorInversion(NestedObject):
     """Applies the following tranformation to a tensor (image or batch of images):
     convert to grayscale, colorize (shift 0-values randomly), and then invert colors

-    ..
-    .. tab:: PyTorch
-    .. code:: python
-    >>> import torch
-    >>> from doctr.transforms import ColorInversion
-    >>> transfo = ColorInversion(min_val=0.6)
-    >>> out = transfo(torch.rand(8, 64, 64, 3))
+    .. code:: python

-    >>> import tensorflow as tf
-    >>> from doctr.transforms import ColorInversion
-    >>> transfo = ColorInversion(min_val=0.6)
-    >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
+    >>> import torch
+    >>> from doctr.transforms import ColorInversion
+    >>> transfo = ColorInversion(min_val=0.6)
+    >>> out = transfo(torch.rand(8, 64, 64, 3))

     Args:
         min_val: range [min_val, 1] to colorize RGB pixels
@@ -136,25 +96,12 @@ class ColorInversion(NestedObject):
 class OneOf(NestedObject):
     """Randomly apply one of the input transformations

-    ..
-    .. tab:: PyTorch
-    .. code:: python
-    >>> import torch
-    >>> from doctr.transforms import OneOf
-    >>> transfo = OneOf([JpegQuality(), Gamma()])
-    >>> out = transfo(torch.rand(1, 64, 64, 3))
-    .. tab:: TensorFlow
+    .. code:: python

-    >>> transfo = OneOf([JpegQuality(), Gamma()])
-    >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+    >>> import torch
+    >>> from doctr.transforms import OneOf
+    >>> transfo = OneOf([JpegQuality(), Gamma()])
+    >>> out = transfo(torch.rand(1, 64, 64, 3))

     Args:
         transforms: list of transformations, one only will be picked
@@ -175,25 +122,12 @@ class OneOf(NestedObject):
 class RandomApply(NestedObject):
     """Apply with a probability p the input transformation

-    ..
-    .. tab:: PyTorch
-    .. code:: python
-    >>> import torch
-    >>> from doctr.transforms import RandomApply
-    >>> transfo = RandomApply(Gamma(), p=.5)
-    >>> out = transfo(torch.rand(1, 64, 64, 3))
-    .. tab:: TensorFlow
-    .. code:: python
+    .. code:: python

+    >>> import torch
+    >>> from doctr.transforms import RandomApply
+    >>> transfo = RandomApply(Gamma(), p=.5)
+    >>> out = transfo(torch.rand(1, 64, 64, 3))

     Args:
         transform: transformation to apply
doctr/transforms/modules/pytorch.py
CHANGED
@@ -13,7 +13,7 @@ from torch.nn.functional import pad
 from torchvision.transforms import functional as F
 from torchvision.transforms import transforms as T

-from ..functional
+from ..functional import random_shadow

 __all__ = [
     "Resize",
@@ -27,7 +27,21 @@ __all__ = [


 class Resize(T.Resize):
-    """Resize the input image to the given size
+    """Resize the input image to the given size
+
+    >>> import torch
+    >>> from doctr.transforms import Resize
+    >>> transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=True)
+    >>> out = transfo(torch.rand((3, 64, 64)))
+
+    Args:
+        size: output size in pixels, either a tuple (height, width) or a single integer for square images
+        interpolation: interpolation mode to use for resizing, default is bilinear
+        preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
+            if True, the image will be resized to fit within the target size while maintaining its aspect ratio
+        symmetric_pad: whether to symmetrically pad the image to the target size,
+            if True, the image will be padded equally on both sides to fit the target size
+    """

     def __init__(
         self,
@@ -36,25 +50,19 @@ class Resize(T.Resize):
         preserve_aspect_ratio: bool = False,
         symmetric_pad: bool = False,
     ) -> None:
-        super().__init__(size, interpolation, antialias=True)
+        super().__init__(size if isinstance(size, (list, tuple)) else (size, size), interpolation, antialias=True)
         self.preserve_aspect_ratio = preserve_aspect_ratio
         self.symmetric_pad = symmetric_pad

-        if not isinstance(self.size, (int, tuple, list)):
-            raise AssertionError("size should be either a tuple, a list or an int")
-
     def forward(
         self,
         img: torch.Tensor,
         target: np.ndarray | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
-            target_ratio = img.shape[-2] / img.shape[-1]
-        else:
-            target_ratio = self.size[0] / self.size[1]
+        target_ratio = self.size[0] / self.size[1]
         actual_ratio = img.shape[-2] / img.shape[-1]

-        if not self.preserve_aspect_ratio or (target_ratio == actual_ratio
+        if not self.preserve_aspect_ratio or (target_ratio == actual_ratio):
             # If we don't preserve the aspect ratio or the wanted aspect ratio is the same than the original one
             # We can use with the regular resize
             if target is not None:
@@ -62,16 +70,10 @@ class Resize(T.Resize):
                 return super().forward(img)
         else:
             # Resize
-            if
-                    tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1])
-            elif isinstance(self.size, int):  # self.size is the longest side, infer the other
-                if img.shape[-2] <= img.shape[-1]:
-                    tmp_size = (max(int(self.size * actual_ratio), 1), self.size)
-                else:
-                    tmp_size = (self.size, max(int(self.size / actual_ratio), 1))
+            if actual_ratio > target_ratio:
+                tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1))
+            else:
+                tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1])

             # Scale image
             img = F.resize(img, tmp_size, self.interpolation, antialias=True)
@@ -93,14 +95,14 @@ class Resize(T.Resize):
         if self.preserve_aspect_ratio:
             # Get absolute coords
             if target.shape[1:] == (4,):
-                if
+                if self.symmetric_pad:
                     target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
                     target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
                 else:
                     target[:, [0, 2]] *= raw_shape[-1] / img.shape[-1]
                     target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
             elif target.shape[1:] == (4, 2):
-                if
+                if self.symmetric_pad:
                     target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
                     target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
                 else:
@@ -143,9 +145,9 @@ class GaussianNoise(torch.nn.Module):
         # Reshape the distribution
         noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std
         if x.dtype == torch.uint8:
-            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
+            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
         else:
-            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)
+            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)

     def extra_repr(self) -> str:
         return f"mean={self.mean}, std={self.std}"
@@ -233,7 +235,7 @@ class RandomShadow(torch.nn.Module):
         try:
             if x.dtype == torch.uint8:
                 return (
-                    (
+                    (
                         255
                         * random_shadow(
                             x.to(dtype=torch.float32) / 255,
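To illustrate the `Resize` changes above — `size` is normalized to a `(height, width)` tuple in `__init__`, and the aspect-preserving branch now pins one side to the target and scales the other before padding — here is a small sketch. The input shape is arbitrary and the expected output shape is inferred from the hunk, not quoted from the diff:

    import torch
    from doctr.transforms import Resize

    transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=True)

    # actual_ratio = 32 / 128 = 0.25 is below target_ratio = 1, so the intermediate size is
    # (max(int(64 * 0.25), 1), 64) = (16, 64); the result is then padded up to (64, 64).
    out = transfo(torch.rand((3, 32, 128)))
    print(out.shape)  # expected: torch.Size([3, 64, 64])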
doctr/utils/data.py
CHANGED
@@ -92,7 +92,7 @@ def download_from_url(
             # Create folder hierarchy
             folder_path.mkdir(parents=True, exist_ok=True)
         except OSError:
-            error_message = f"Failed creating cache
+            error_message = f"Failed creating cache directory at {folder_path}"
             if os.environ.get("DOCTR_CACHE_DIR", ""):
                 error_message += " using path from 'DOCTR_CACHE_DIR' environment variable."
             else:
doctr/utils/geometry.py
CHANGED
@@ -300,7 +300,7 @@ def rotate_image(
     # Compute the expanded padding
     exp_img: np.ndarray
     if expand:
-        exp_shape = compute_expanded_shape(image.shape[:2], angle)
+        exp_shape = compute_expanded_shape(image.shape[:2], angle)
         h_pad, w_pad = (
             int(max(0, ceil(exp_shape[0] - image.shape[0]))),
             int(max(0, ceil(exp_shape[1] - image.shape[1]))),
@@ -390,14 +390,13 @@ def convert_to_relative_coords(geoms: np.ndarray, img_shape: tuple[int, int]) ->
     raise ValueError(f"invalid format for arg `geoms`: {geoms.shape}")


-def extract_crops(img: np.ndarray, boxes: np.ndarray
+def extract_crops(img: np.ndarray, boxes: np.ndarray) -> list[np.ndarray]:
     """Created cropped images from list of bounding boxes

     Args:
         img: input image
         boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative
             coordinates (xmin, ymin, xmax, ymax)
-        channels_last: whether the channel dimensions is the last one instead of the last one

     Returns:
         list of cropped images
@@ -409,21 +408,19 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True

     # Project relative coordinates
     _boxes = boxes.copy()
-    h, w = img.shape[:2]
+    h, w = img.shape[:2]
     if not np.issubdtype(_boxes.dtype, np.integer):
         _boxes[:, [0, 2]] *= w
         _boxes[:, [1, 3]] *= h
         _boxes = _boxes.round().astype(int)
     # Add last index
     _boxes[2:] += 1
-    if channels_last:
-        return deepcopy([img[box[1] : box[3], box[0] : box[2]] for box in _boxes])

-    return deepcopy([img[
+    return deepcopy([img[box[1] : box[3], box[0] : box[2]] for box in _boxes])


 def extract_rcrops(
-    img: np.ndarray, polys: np.ndarray, dtype=np.float32,
+    img: np.ndarray, polys: np.ndarray, dtype=np.float32, assume_horizontal: bool = False
 ) -> list[np.ndarray]:
     """Created cropped images from list of rotated bounding boxes

@@ -431,7 +428,6 @@ def extract_rcrops(
         img: input image
         polys: bounding boxes of shape (N, 4, 2)
         dtype: target data type of bounding boxes
-        channels_last: whether the channel dimensions is the last one instead of the last one
         assume_horizontal: whether the boxes are assumed to be only horizontally oriented

     Returns:
@@ -444,12 +440,12 @@ def extract_rcrops(

     # Project relative coordinates
     _boxes = polys.copy()
-    height, width = img.shape[:2]
+    height, width = img.shape[:2]
     if not np.issubdtype(_boxes.dtype, np.integer):
         _boxes[:, :, 0] *= width
         _boxes[:, :, 1] *= height

-    src_img = img
+    src_img = img

     # Handle only horizontal oriented boxes
     if assume_horizontal:
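A short sketch of the updated cropping helpers above: `extract_crops` drops its `channels_last` flag (images are treated as HWC arrays) and `extract_rcrops` gains `assume_horizontal`. The image and box values below are made up for illustration:

    import numpy as np
    from doctr.utils.geometry import extract_crops, extract_rcrops

    img = np.zeros((64, 128, 3), dtype=np.uint8)  # HWC image

    # Relative (xmin, ymin, xmax, ymax) boxes of shape (N, 4)
    boxes = np.array([[0.1, 0.25, 0.5, 0.75]], dtype=np.float32)
    crops = extract_crops(img, boxes)  # list of HWC numpy crops

    # Rotated boxes as (N, 4, 2) polygons; here an axis-aligned rectangle for simplicity
    polys = np.array([[[0.1, 0.25], [0.5, 0.25], [0.5, 0.75], [0.1, 0.75]]], dtype=np.float32)
    rcrops = extract_rcrops(img, polys, assume_horizontal=True)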
doctr/utils/visualization.py
CHANGED
@@ -148,7 +148,7 @@ def get_colors(num_colors: int) -> list[tuple[float, float, float]]:
         hue = i / 360.0
         lightness = (50 + np.random.rand() * 10) / 100.0
         saturation = (90 + np.random.rand() * 10) / 100.0
-        colors.append(colorsys.hls_to_rgb(hue, lightness, saturation))
+        colors.append(colorsys.hls_to_rgb(hue, lightness, saturation))  # type: ignore[arg-type]
     return colors

doctr/version.py
CHANGED
@@ -1 +1 @@
-__version__ = '
+__version__ = 'v1.0.0'
|