python-doctr 0.12.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +0 -5
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/vocabs.py +0 -2
- doctr/file_utils.py +2 -101
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +3 -3
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +2 -2
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +1 -1
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +1 -1
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vip/__init__.py +1 -4
- doctr/models/classification/vip/layers/__init__.py +1 -4
- doctr/models/classification/vip/layers/pytorch.py +1 -1
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +2 -2
- doctr/models/classification/zoo.py +6 -11
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +4 -12
- doctr/models/detection/differentiable_binarization/pytorch.py +3 -3
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +4 -14
- doctr/models/detection/fast/pytorch.py +4 -4
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +3 -12
- doctr/models/detection/linknet/pytorch.py +2 -2
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +8 -21
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +2 -6
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +3 -3
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +2 -5
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +6 -6
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +5 -5
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +5 -5
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +7 -16
- doctr/models/recognition/predictor/pytorch.py +1 -2
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +3 -3
- doctr/models/recognition/viptr/__init__.py +1 -4
- doctr/models/recognition/viptr/pytorch.py +3 -3
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +3 -3
- doctr/models/recognition/zoo.py +13 -13
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +1 -1
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/geometry.py +6 -10
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +18 -75
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -442
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -418
- doctr/models/classification/textnet/tensorflow.py +0 -275
- doctr/models/classification/vgg/tensorflow.py +0 -125
- doctr/models/classification/vit/tensorflow.py +0 -201
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
- doctr/models/detection/fast/tensorflow.py +0 -427
- doctr/models/detection/linknet/tensorflow.py +0 -377
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -317
- doctr/models/recognition/master/tensorflow.py +0 -320
- doctr/models/recognition/parseq/tensorflow.py +0 -516
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -423
- doctr/models/recognition/vitstr/tensorflow.py +0 -285
- doctr/models/utils/tensorflow.py +0 -189
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.12.0.dist-info/RECORD +0 -180
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
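
Every tensorflow.py backend module listed above is removed in 1.0.0, leaving PyTorch as the only supported backend. A minimal usage sketch, assuming the high-level ocr_predictor entry point is unchanged, a PyTorch install of docTR 1.0.0, and a local "sample.pdf" that you supply yourself:

# Hedged sketch: the high-level API call pattern, now always running on the PyTorch backend.
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

doc = DocumentFile.from_pdf("sample.pdf")   # pages are loaded as numpy arrays
predictor = ocr_predictor(pretrained=True)  # detection + recognition pipeline
result = predictor(doc)                     # run end-to-end OCR
print(result.render())                      # plain-text export of the parsed pages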
doctr/models/recognition/parseq/pytorch.py
CHANGED

@@ -19,7 +19,7 @@ from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

 from ...classification import vit_s
-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor

 __all__ = ["PARSeq", "parseq"]

@@ -299,7 +299,7 @@ class PARSeq(_PARSeq, nn.Module):

 # Stop decoding if all sequences have reached the EOS token
 # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
+if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
 break

 logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1)

@@ -314,7 +314,7 @@ class PARSeq(_PARSeq, nn.Module):

 # Create padding mask for refined target input maskes all behind EOS token as False
 # (N, 1, 1, max_length)
-target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
 mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
 logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))

@@ -367,7 +367,7 @@ class PARSeq(_PARSeq, nn.Module):
 # remove the [EOS] tokens for the succeeding perms
 if i == 1:
 gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
-n = (gt_out != self.vocab_size + 2).sum().item()
+n = (gt_out != self.vocab_size + 2).sum().item()

 loss /= loss_numel

@@ -391,7 +391,7 @@ class PARSeq(_PARSeq, nn.Module):

 if target is None or return_preds:
 # Disable for torch.compile compatibility
-@torch.compiler.disable
+@torch.compiler.disable
 def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
 return self.postprocessor(logits)

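The @torch.compiler.disable decorator seen above (and in the SAR, VIPTR, and ViTSTR diffs below) keeps the Python-level postprocessing out of the captured graph so the model itself can still be wrapped with torch.compile. A minimal standalone sketch of the pattern, assuming a recent PyTorch (roughly 2.1+); TinyModel and _postprocess are illustrative names, not part of docTR:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Illustrative module, not a docTR class.
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> list[int]:
        logits = self.linear(x)

        @torch.compiler.disable  # runs eagerly, excluded from graph capture
        def _postprocess(out: torch.Tensor) -> list[int]:
            return out.argmax(-1).tolist()

        return _postprocess(logits)

compiled = torch.compile(TinyModel())
print(compiled(torch.rand(2, 8)))  # e.g. [3, 1]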
doctr/models/recognition/predictor/_utils.py
CHANGED

@@ -18,17 +18,15 @@ def split_crops(
 max_ratio: float,
 target_ratio: int,
 split_overlap_ratio: float,
-channels_last: bool = True,
 ) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
 """
 Split crops horizontally if they exceed a given aspect ratio.

 Args:
-crops: List of image crops (H, W, C)
+crops: List of image crops (H, W, C).
 max_ratio: Aspect ratio threshold above which crops are split.
 target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
 split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
-channels_last: Whether the crops are in channels-last format.

 Returns:
 A tuple containing:

@@ -44,14 +42,14 @@ def split_crops(
 crop_map: list[int | tuple[int, int, float]] = []

 for crop in crops:
-h, w = crop.shape[:2]
+h, w = crop.shape[:2]
 aspect_ratio = w / h

 if aspect_ratio > max_ratio:
 split_width = max(1, math.ceil(h * target_ratio))
 overlap_width = max(0, math.floor(split_width * split_overlap_ratio))

-splits, last_overlap = _split_horizontally(crop, split_width, overlap_width
+splits, last_overlap = _split_horizontally(crop, split_width, overlap_width)

 # Remove any empty splits
 splits = [s for s in splits if all(dim > 0 for dim in s.shape)]

@@ -70,23 +68,20 @@
 return new_crops, crop_map, remap_required


-def _split_horizontally(
-image: np.ndarray, split_width: int, overlap_width: int, channels_last: bool
-) -> tuple[list[np.ndarray], float]:
+def _split_horizontally(image: np.ndarray, split_width: int, overlap_width: int) -> tuple[list[np.ndarray], float]:
 """
 Horizontally split a single image with overlapping regions.

 Args:
-image: The image to split (H, W, C)
+image: The image to split (H, W, C).
 split_width: Width of each split.
 overlap_width: Width of the overlapping region.
-channels_last: Whether the image is in channels-last format.

 Returns:
 - A list of horizontal image slices.
 - The actual overlap ratio of the last split.
 """
-image_width = image.shape[1]
+image_width = image.shape[1]
 if image_width <= split_width:
 return [image], 0.0

@@ -101,11 +96,7 @@ def _split_horizontally(
 splits = []
 for start_col in starts:
 end_col = start_col + split_width
-
-split = image[:, start_col:end_col, :]
-else:
-split = image[:, :, start_col:end_col]
-splits.append(split)
+splits.append(image[:, start_col:end_col, :])

 # Calculate the last overlap ratio, if only one split no overlap
 last_overlap = 0

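With the channels_last flag gone, split_crops now always expects (H, W, C) numpy crops. A hedged sketch of calling the updated helper; it is a private utility, and the threshold values below are illustrative only:

import numpy as np
from doctr.models.recognition.predictor._utils import split_crops

crops = [np.random.randint(0, 255, (32, 512, 3), dtype=np.uint8)]  # one very wide (H, W, C) crop
new_crops, crop_map, remapped = split_crops(
    crops,
    max_ratio=8.0,            # only split crops wider than 8:1
    target_ratio=4,           # aim for roughly 4:1 slices (e.g. 128x32)
    split_overlap_ratio=0.5,  # 50% overlap between consecutive slices
)
print(len(new_crops), crop_map, remapped)  # split pieces, remapping info, whether remapping is needed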
doctr/models/recognition/predictor/pytorch.py
CHANGED

@@ -44,7 +44,7 @@ class RecognitionPredictor(nn.Module):
 @torch.inference_mode()
 def forward(
 self,
-crops: Sequence[np.ndarray
+crops: Sequence[np.ndarray],
 **kwargs: Any,
 ) -> list[tuple[str, float]]:
 if len(crops) == 0:

@@ -61,7 +61,6 @@ class RecognitionPredictor(nn.Module):
 self.critical_ar,
 self.target_ar,
 self.overlap_ratio,
-isinstance(crops[0], np.ndarray),
 )
 if remapped:
 crops = new_crops

doctr/models/recognition/sar/pytorch.py
CHANGED

@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS

 from ...classification import resnet31
-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor

 __all__ = ["SAR", "sar_resnet31"]

@@ -272,7 +272,7 @@ class SAR(nn.Module, RecognitionModel):

 if target is None or return_preds:
 # Disable for torch.compile compatibility
-@torch.compiler.disable
+@torch.compiler.disable
 def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
 return self.postprocessor(decoded_features)

@@ -304,7 +304,7 @@ class SAR(nn.Module, RecognitionModel):
 # Input length : number of timesteps
 input_len = model_output.shape[1]
 # Add one for additional <eos> token
-seq_len = seq_len + 1
+seq_len = seq_len + 1
 # Compute loss
 # (N, L, vocab_size + 1)
 cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")

doctr/models/recognition/viptr/pytorch.py
CHANGED

@@ -16,7 +16,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS, decode_sequence

 from ...classification import vip_tiny
-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor

 __all__ = ["VIPTR", "viptr_tiny"]

@@ -70,7 +70,7 @@ class VIPTRPostProcessor(RecognitionPostProcessor):

 def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
 """Performs decoding of raw output with CTC and decoding of CTC predictions
-with label_to_idx mapping
+with label_to_idx mapping dictionary

 Args:
 logits: raw output of the model, shape (N, C + 1, seq_len)

@@ -166,7 +166,7 @@ class VIPTR(RecognitionModel, nn.Module):

 if target is None or return_preds:
 # Disable for torch.compile compatibility
-@torch.compiler.disable
+@torch.compiler.disable
 def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
 return self.postprocessor(decoded_features)

doctr/models/recognition/vitstr/pytorch.py
CHANGED

@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS

 from ...classification import vit_b, vit_s
-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor

 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

@@ -117,7 +117,7 @@ class ViTSTR(_ViTSTR, nn.Module):

 if target is None or return_preds:
 # Disable for torch.compile compatibility
-@torch.compiler.disable
+@torch.compiler.disable
 def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
 return self.postprocessor(decoded_features)

@@ -149,7 +149,7 @@ class ViTSTR(_ViTSTR, nn.Module):
 # Input length : number of steps
 input_len = model_output.shape[1]
 # Add one for additional <eos> token (sos disappear in shift!)
-seq_len = seq_len + 1
+seq_len = seq_len + 1
 # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
 # The "masked" first gt char is <sos>.
 cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")

doctr/models/recognition/zoo.py
CHANGED
@@ -5,8 +5,8 @@

 from typing import Any

-from doctr.file_utils import is_tf_available, is_torch_available
 from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import _CompiledModule

 from .. import recognition
 from .predictor import RecognitionPredictor

@@ -23,11 +23,9 @@ ARCHS: list[str] = [
 "vitstr_small",
 "vitstr_base",
 "parseq",
+"viptr_tiny",
 ]

-if is_torch_available():
-ARCHS.extend(["viptr_tiny"])
-

 def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
 if isinstance(arch, str):

@@ -38,14 +36,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
 pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
 )
 else:
-
-
-
-
-
-
-
-
+# Adding the type for torch compiled models to the allowed architectures
+allowed_archs = [
+recognition.CRNN,
+recognition.SAR,
+recognition.MASTER,
+recognition.ViTSTR,
+recognition.PARSeq,
+recognition.VIPTR,
+_CompiledModule,
+]

 if not isinstance(arch, tuple(allowed_archs)):
 raise ValueError(f"unknown architecture: {type(arch)}")

@@ -56,7 +56,7 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
 kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
 kwargs["std"] = kwargs.get("std", _model.cfg["std"])
 kwargs["batch_size"] = kwargs.get("batch_size", 128)
-input_shape = _model.cfg["input_shape"][
+input_shape = _model.cfg["input_shape"][-2:]
 predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)

 return predictor

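Since the is_torch_available() gate is gone, "viptr_tiny" is always part of ARCHS. A minimal sketch using the public recognition predictor, assuming pretrained weights can be downloaded:

import numpy as np
from doctr.models import recognition_predictor

reco = recognition_predictor("viptr_tiny", pretrained=True)
word_crop = np.random.randint(0, 255, (32, 128, 3), dtype=np.uint8)  # dummy (H, W, C) word crop
print(reco([word_crop]))  # list of (word, confidence) tuples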
doctr/models/utils/__init__.py
CHANGED
doctr/models/utils/pytorch.py
CHANGED
@@ -164,7 +164,7 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
 """
 torch.onnx.export(
 model,
-dummy_input,
+dummy_input, # type: ignore[arg-type]
 f"{model_name}.onnx",
 input_names=["input"],
 output_names=["logits"],

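A hedged sketch of the exporter touched above, following the ONNX export recipe from the docTR documentation; the architecture and dummy-input shape are illustrative:

import torch
from doctr.models import crnn_vgg16_bn
from doctr.models.utils import export_model_to_onnx

model = crnn_vgg16_bn(pretrained=True).eval()
dummy_input = torch.rand((1, 3, 32, 128), dtype=torch.float32)  # (N, C, H, W) recognition input
model_path = export_model_to_onnx(model, model_name="crnn_vgg16_bn", dummy_input=dummy_input)
print(model_path)  # expected to be "crnn_vgg16_bn.onnx"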
doctr/transforms/functional/pytorch.py
CHANGED

@@ -33,9 +33,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
 rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
 # Inverse the color
 if out.dtype == torch.uint8:
-out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
+out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
 else:
-out = out * rgb_shift.to(dtype=out.dtype)
+out = out * rgb_shift.to(dtype=out.dtype)
 # Inverse the color
 out = 255 - out if out.dtype == torch.uint8 else 1 - out
 return out

@@ -77,7 +77,7 @@ def rotate_sample(
 rotated_geoms: np.ndarray = rotate_abs_geoms(
 _geoms,
 angle,
-img.shape[1:],
+img.shape[1:], # type: ignore[arg-type]
 expand,
 ).astype(np.float32)

@@ -124,7 +124,7 @@ def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwarg
 Returns:
 Shadowed image as a PyTorch tensor (same shape as input).
 """
-shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
+shadow_mask = create_shadow_mask(img.shape[1:], **kwargs) # type: ignore[arg-type]
 opacity = np.random.uniform(*opacity_range)

 # Apply Gaussian blur to the shadow mask

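For reference, a hedged sketch of the functional helper whose call site gains a type: ignore above; it expects a channels-first tensor, and the opacity range here is illustrative:

import torch
from doctr.transforms.functional import random_shadow

img = torch.rand(3, 64, 64)  # (C, H, W) image in [0, 1]
shadowed = random_shadow(img, opacity_range=(0.2, 0.8))
print(shadowed.shape)  # same shape as the input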
doctr/transforms/modules/base.py
CHANGED
@@ -20,27 +20,13 @@ __all__ = ["SampleCompose", "ImageTransform", "ColorInversion", "OneOf", "Random
 class SampleCompose(NestedObject):
 """Implements a wrapper that will apply transformations sequentially on both image and target

-..
+.. code:: python

-
-
-
-
-
->>> import torch
->>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
->>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
->>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
-
-.. tab:: TensorFlow
-
-.. code:: python
-
->>> import numpy as np
->>> import tensorflow as tf
->>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
->>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
->>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
+>>> import numpy as np
+>>> import torch
+>>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
+>>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
+>>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))

 Args:
 transforms: list of transformation modules

@@ -61,25 +47,12 @@ class SampleCompose(NestedObject):
 class ImageTransform(NestedObject):
 """Implements a transform wrapper to turn an image-only transformation into an image+target transform

-..
-
-.. tab:: PyTorch
-
-.. code:: python
+.. code:: python

-
-
-
-
-
-.. tab:: TensorFlow
-
-.. code:: python
-
->>> import tensorflow as tf
->>> from doctr.transforms import ImageTransform, ColorInversion
->>> transfo = ImageTransform(ColorInversion((32, 32)))
->>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
+>>> import torch
+>>> from doctr.transforms import ImageTransform, ColorInversion
+>>> transfo = ImageTransform(ColorInversion((32, 32)))
+>>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)

 Args:
 transform: the image transformation module to wrap

@@ -99,25 +72,12 @@ class ColorInversion(NestedObject):
 """Applies the following tranformation to a tensor (image or batch of images):
 convert to grayscale, colorize (shift 0-values randomly), and then invert colors

-..
-
-.. tab:: PyTorch
-
-.. code:: python
-
->>> import torch
->>> from doctr.transforms import ColorInversion
->>> transfo = ColorInversion(min_val=0.6)
->>> out = transfo(torch.rand(8, 64, 64, 3))
+.. code:: python

-
-
-
-
->>> import tensorflow as tf
->>> from doctr.transforms import ColorInversion
->>> transfo = ColorInversion(min_val=0.6)
->>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
+>>> import torch
+>>> from doctr.transforms import ColorInversion
+>>> transfo = ColorInversion(min_val=0.6)
+>>> out = transfo(torch.rand(8, 64, 64, 3))

 Args:
 min_val: range [min_val, 1] to colorize RGB pixels

@@ -136,25 +96,12 @@ class ColorInversion(NestedObject):
 class OneOf(NestedObject):
 """Randomly apply one of the input transformations

-..
-
-.. tab:: PyTorch
-
-.. code:: python
-
->>> import torch
->>> from doctr.transforms import OneOf
->>> transfo = OneOf([JpegQuality(), Gamma()])
->>> out = transfo(torch.rand(1, 64, 64, 3))
-
-.. tab:: TensorFlow
+.. code:: python

-
-
-
-
->>> transfo = OneOf([JpegQuality(), Gamma()])
->>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+>>> import torch
+>>> from doctr.transforms import OneOf
+>>> transfo = OneOf([JpegQuality(), Gamma()])
+>>> out = transfo(torch.rand(1, 64, 64, 3))

 Args:
 transforms: list of transformations, one only will be picked

@@ -175,25 +122,12 @@ class OneOf(NestedObject):
 class RandomApply(NestedObject):
 """Apply with a probability p the input transformation

-..
-
-.. tab:: PyTorch
-
-.. code:: python
-
->>> import torch
->>> from doctr.transforms import RandomApply
->>> transfo = RandomApply(Gamma(), p=.5)
->>> out = transfo(torch.rand(1, 64, 64, 3))
-
-.. tab:: TensorFlow
-
-.. code:: python
+.. code:: python

-
-
-
-
+>>> import torch
+>>> from doctr.transforms import RandomApply
+>>> transfo = RandomApply(Gamma(), p=.5)
+>>> out = transfo(torch.rand(1, 64, 64, 3))

 Args:
 transform: transformation to apply

doctr/transforms/modules/pytorch.py
CHANGED

@@ -13,7 +13,7 @@ from torch.nn.functional import pad
 from torchvision.transforms import functional as F
 from torchvision.transforms import transforms as T

-from ..functional
+from ..functional import random_shadow

 __all__ = [
 "Resize",

@@ -27,7 +27,21 @@ __all__ = [


 class Resize(T.Resize):
-"""Resize the input image to the given size
+"""Resize the input image to the given size
+
+>>> import torch
+>>> from doctr.transforms import Resize
+>>> transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=True)
+>>> out = transfo(torch.rand((3, 64, 64)))
+
+Args:
+size: output size in pixels, either a tuple (height, width) or a single integer for square images
+interpolation: interpolation mode to use for resizing, default is bilinear
+preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
+if True, the image will be resized to fit within the target size while maintaining its aspect ratio
+symmetric_pad: whether to symmetrically pad the image to the target size,
+if True, the image will be padded equally on both sides to fit the target size
+"""

 def __init__(
 self,

@@ -36,25 +50,19 @@ class Resize(T.Resize):
 preserve_aspect_ratio: bool = False,
 symmetric_pad: bool = False,
 ) -> None:
-super().__init__(size, interpolation, antialias=True)
+super().__init__(size if isinstance(size, (list, tuple)) else (size, size), interpolation, antialias=True)
 self.preserve_aspect_ratio = preserve_aspect_ratio
 self.symmetric_pad = symmetric_pad

-if not isinstance(self.size, (int, tuple, list)):
-raise AssertionError("size should be either a tuple, a list or an int")
-
 def forward(
 self,
 img: torch.Tensor,
 target: np.ndarray | None = None,
 ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
-
-target_ratio = img.shape[-2] / img.shape[-1]
-else:
-target_ratio = self.size[0] / self.size[1]
+target_ratio = self.size[0] / self.size[1]
 actual_ratio = img.shape[-2] / img.shape[-1]

-if not self.preserve_aspect_ratio or (target_ratio == actual_ratio
+if not self.preserve_aspect_ratio or (target_ratio == actual_ratio):
 # If we don't preserve the aspect ratio or the wanted aspect ratio is the same than the original one
 # We can use with the regular resize
 if target is not None:

@@ -62,16 +70,10 @@ class Resize(T.Resize):
 return super().forward(img)
 else:
 # Resize
-if
-
-
-
-tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1])
-elif isinstance(self.size, int): # self.size is the longest side, infer the other
-if img.shape[-2] <= img.shape[-1]:
-tmp_size = (max(int(self.size * actual_ratio), 1), self.size)
-else:
-tmp_size = (self.size, max(int(self.size / actual_ratio), 1))
+if actual_ratio > target_ratio:
+tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1))
+else:
+tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1])

 # Scale image
 img = F.resize(img, tmp_size, self.interpolation, antialias=True)

@@ -93,14 +95,14 @@ class Resize(T.Resize):
 if self.preserve_aspect_ratio:
 # Get absolute coords
 if target.shape[1:] == (4,):
-if
+if self.symmetric_pad:
 target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
 target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
 else:
 target[:, [0, 2]] *= raw_shape[-1] / img.shape[-1]
 target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
 elif target.shape[1:] == (4, 2):
-if
+if self.symmetric_pad:
 target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
 target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
 else:

@@ -143,9 +145,9 @@ class GaussianNoise(torch.nn.Module):
 # Reshape the distribution
 noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std
 if x.dtype == torch.uint8:
-return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
+return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
 else:
-return (x + noise.to(dtype=x.dtype)).clamp(0, 1)
+return (x + noise.to(dtype=x.dtype)).clamp(0, 1)

 def extra_repr(self) -> str:
 return f"mean={self.mean}, std={self.std}"

@@ -233,7 +235,7 @@ class RandomShadow(torch.nn.Module):
 try:
 if x.dtype == torch.uint8:
 return (
-(
+(
 255
 * random_shadow(
 x.to(dtype=torch.float32) / 255,