python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/contrib/artefacts.py +1 -1
- doctr/contrib/base.py +1 -1
- doctr/datasets/__init__.py +0 -5
- doctr/datasets/coco_text.py +1 -1
- doctr/datasets/cord.py +1 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/base.py +1 -1
- doctr/datasets/datasets/pytorch.py +3 -3
- doctr/datasets/detection.py +1 -1
- doctr/datasets/doc_artefacts.py +1 -1
- doctr/datasets/funsd.py +1 -1
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/generator/base.py +1 -1
- doctr/datasets/generator/pytorch.py +1 -1
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +1 -1
- doctr/datasets/iiit5k.py +1 -1
- doctr/datasets/iiithws.py +1 -1
- doctr/datasets/imgur5k.py +1 -1
- doctr/datasets/mjsynth.py +1 -1
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/orientation.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/sroie.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +1 -1
- doctr/datasets/synthtext.py +1 -1
- doctr/datasets/utils.py +1 -1
- doctr/datasets/vocabs.py +1 -3
- doctr/datasets/wildreceipt.py +1 -1
- doctr/file_utils.py +3 -102
- doctr/io/elements.py +1 -1
- doctr/io/html.py +1 -1
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/base.py +1 -1
- doctr/io/image/pytorch.py +2 -2
- doctr/io/pdf.py +1 -1
- doctr/io/reader.py +1 -1
- doctr/models/_utils.py +56 -18
- doctr/models/builder.py +1 -1
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -3
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +1 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +1 -1
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +2 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +1 -1
- doctr/models/classification/vip/__init__.py +1 -4
- doctr/models/classification/vip/layers/__init__.py +1 -4
- doctr/models/classification/vip/layers/pytorch.py +2 -2
- doctr/models/classification/vip/pytorch.py +1 -1
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +3 -3
- doctr/models/classification/zoo.py +7 -12
- doctr/models/core.py +1 -1
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/_utils/base.py +1 -1
- doctr/models/detection/_utils/pytorch.py +1 -1
- doctr/models/detection/core.py +2 -2
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +5 -13
- doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +5 -15
- doctr/models/detection/fast/pytorch.py +5 -5
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +4 -13
- doctr/models/detection/linknet/pytorch.py +3 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +2 -2
- doctr/models/detection/zoo.py +16 -33
- doctr/models/factory/hub.py +26 -34
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/base.py +1 -1
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +4 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +3 -3
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +4 -9
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +28 -33
- doctr/models/recognition/core.py +1 -1
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +7 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/base.py +1 -1
- doctr/models/recognition/master/pytorch.py +6 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/base.py +1 -1
- doctr/models/recognition/parseq/pytorch.py +6 -6
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +8 -17
- doctr/models/recognition/predictor/pytorch.py +2 -3
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +4 -4
- doctr/models/recognition/utils.py +1 -1
- doctr/models/recognition/viptr/__init__.py +1 -4
- doctr/models/recognition/viptr/pytorch.py +4 -4
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/base.py +1 -1
- doctr/models/recognition/vitstr/pytorch.py +4 -4
- doctr/models/recognition/zoo.py +14 -14
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +3 -2
- doctr/models/zoo.py +1 -1
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/base.py +3 -2
- doctr/transforms/functional/pytorch.py +5 -5
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +28 -94
- doctr/transforms/modules/pytorch.py +29 -27
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +1 -2
- doctr/utils/fonts.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/metrics.py +1 -1
- doctr/utils/multithreading.py +1 -1
- doctr/utils/reconstitution.py +1 -1
- doctr/utils/repr.py +1 -1
- doctr/utils/visualization.py +2 -2
- doctr/version.py +1 -1
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
- python_doctr-1.0.1.dist-info/RECORD +149 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -442
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -418
- doctr/models/classification/textnet/tensorflow.py +0 -275
- doctr/models/classification/vgg/tensorflow.py +0 -125
- doctr/models/classification/vit/tensorflow.py +0 -201
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
- doctr/models/detection/fast/tensorflow.py +0 -427
- doctr/models/detection/linknet/tensorflow.py +0 -377
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -317
- doctr/models/recognition/master/tensorflow.py +0 -320
- doctr/models/recognition/parseq/tensorflow.py +0 -516
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -423
- doctr/models/recognition/vitstr/tensorflow.py +0 -285
- doctr/models/utils/tensorflow.py +0 -189
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.12.0.dist-info/RECORD +0 -180
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
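The headline change in 1.0.x is visible in the file list above: every tensorflow.py backend module is deleted, leaving PyTorch as the only supported backend. The public entry point is unchanged; below is a minimal usage sketch (the PDF path is a placeholder, and pretrained weights are downloaded on first use):

# Minimal end-to-end sketch against the PyTorch-only package; "sample.pdf" is illustrative.
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

doc = DocumentFile.from_pdf("sample.pdf")    # pages as numpy arrays (H, W, C)
predictor = ocr_predictor(pretrained=True)   # text detection + recognition pipeline
result = predictor(doc)                      # structured Document object
print(result.render())                       # plain-text export of the OCR result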
doctr/models/recognition/parseq/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -19,7 +19,7 @@ from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
 
 from ...classification import vit_s
-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]

@@ -299,7 +299,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         # Stop decoding if all sequences have reached the EOS token
         # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-        if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
+        if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
             break
 
         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)

@@ -314,7 +314,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         # Create padding mask for refined target input maskes all behind EOS token as False
         # (N, 1, 1, max_length)
-        target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+        target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
         mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
         logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
 

@@ -367,7 +367,7 @@ class PARSeq(_PARSeq, nn.Module):
         # remove the [EOS] tokens for the succeeding perms
         if i == 1:
             gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
-            n = (gt_out != self.vocab_size + 2).sum().item()
+            n = (gt_out != self.vocab_size + 2).sum().item()
 
         loss /= loss_numel
 

@@ -391,7 +391,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)
 
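The @torch.compiler.disable decorator seen above (and repeated in the SAR, VIPTR and ViTSTR hunks below) keeps the Python-level string decoding out of the graph captured by torch.compile. A self-contained sketch of that pattern, using a toy module rather than doctr's actual classes:

import torch
import torch.nn as nn

class TinyRecognizer(nn.Module):
    """Toy stand-in for a recognition head; only the decorator pattern matters here."""

    def __init__(self, vocab: str = "abc"):
        super().__init__()
        self.vocab = vocab
        self.head = nn.Linear(8, len(vocab) + 1)  # +1 for an <eos>-style token

    def forward(self, feats: torch.Tensor) -> list[tuple[str, float]]:
        logits = self.head(feats)  # (N, T, C + 1)

        @torch.compiler.disable  # decoding to Python strings is not compile-friendly
        def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
            probs, idx = logits.softmax(-1).max(-1)
            words = ["".join(self.vocab[i] for i in seq if i < len(self.vocab)) for seq in idx.tolist()]
            return list(zip(words, probs.mean(-1).tolist()))

        return _postprocess(logits)

preds = TinyRecognizer()(torch.rand(2, 5, 8))  # list of (word, confidence) pairs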
doctr/models/recognition/predictor/_utils.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -18,17 +18,15 @@ def split_crops(
     max_ratio: float,
     target_ratio: int,
     split_overlap_ratio: float,
-    channels_last: bool = True,
 ) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
     """
     Split crops horizontally if they exceed a given aspect ratio.
 
     Args:
-        crops: List of image crops (H, W, C)
+        crops: List of image crops (H, W, C).
         max_ratio: Aspect ratio threshold above which crops are split.
         target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
         split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
-        channels_last: Whether the crops are in channels-last format.
 
     Returns:
         A tuple containing:

@@ -44,14 +42,14 @@ def split_crops(
     crop_map: list[int | tuple[int, int, float]] = []
 
     for crop in crops:
-        h, w = crop.shape[:2]
+        h, w = crop.shape[:2]
         aspect_ratio = w / h
 
         if aspect_ratio > max_ratio:
             split_width = max(1, math.ceil(h * target_ratio))
             overlap_width = max(0, math.floor(split_width * split_overlap_ratio))
 
-            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width
+            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width)
 
             # Remove any empty splits
             splits = [s for s in splits if all(dim > 0 for dim in s.shape)]

@@ -70,23 +68,20 @@
     return new_crops, crop_map, remap_required
 
 
-def _split_horizontally(
-    image: np.ndarray, split_width: int, overlap_width: int, channels_last: bool
-) -> tuple[list[np.ndarray], float]:
+def _split_horizontally(image: np.ndarray, split_width: int, overlap_width: int) -> tuple[list[np.ndarray], float]:
     """
     Horizontally split a single image with overlapping regions.
 
     Args:
-        image: The image to split (H, W, C)
+        image: The image to split (H, W, C).
         split_width: Width of each split.
         overlap_width: Width of the overlapping region.
-        channels_last: Whether the image is in channels-last format.
 
     Returns:
         - A list of horizontal image slices.
         - The actual overlap ratio of the last split.
     """
-    image_width = image.shape[1]
+    image_width = image.shape[1]
     if image_width <= split_width:
         return [image], 0.0
 

@@ -101,11 +96,7 @@ def _split_horizontally(
     splits = []
     for start_col in starts:
         end_col = start_col + split_width
-        if channels_last:
-            split = image[:, start_col:end_col, :]
-        else:
-            split = image[:, :, start_col:end_col]
-        splits.append(split)
+        splits.append(image[:, start_col:end_col, :])
 
     # Calculate the last overlap ratio, if only one split no overlap
     last_overlap = 0
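The refactor above drops the channels_last switch: recognition crops are now always (H, W, C) numpy arrays, so _split_horizontally slices along axis 1 unconditionally. The hunks do not show how the start columns are computed, but the returned "actual overlap ratio of the last split" can be illustrated with made-up numbers (a hypothetical 300 px wide crop, 128 px splits, 32 px desired overlap):

# Illustrative arithmetic only, not doctr's code: one plausible way to place
# overlapping splits and measure the real overlap of a right-aligned last split.
image_width, split_width, overlap_width = 300, 128, 32

stride = split_width - overlap_width                              # 96
starts = list(range(0, image_width - split_width + 1, stride))    # [0, 96]
if starts[-1] + split_width < image_width:
    starts.append(image_width - split_width)                      # right-align the final split: 172
last_overlap = (starts[-2] + split_width - starts[-1]) / split_width if len(starts) > 1 else 0.0
print(starts, round(last_overlap, 3))                             # [0, 96, 172] 0.406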
doctr/models/recognition/predictor/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -44,7 +44,7 @@ class RecognitionPredictor(nn.Module):
    @torch.inference_mode()
    def forward(
        self,
-        crops: Sequence[np.ndarray
+        crops: Sequence[np.ndarray],
        **kwargs: Any,
    ) -> list[tuple[str, float]]:
        if len(crops) == 0:

@@ -61,7 +61,6 @@ class RecognitionPredictor(nn.Module):
            self.critical_ar,
            self.target_ar,
            self.overlap_ratio,
-            isinstance(crops[0], np.ndarray),
        )
        if remapped:
            crops = new_crops
doctr/models/recognition/sar/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS
 
 from ...classification import resnet31
-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]

@@ -272,7 +272,7 @@ class SAR(nn.Module, RecognitionModel):
 
        if target is None or return_preds:
            # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
                return self.postprocessor(decoded_features)
 

@@ -304,7 +304,7 @@ class SAR(nn.Module, RecognitionModel):
        # Input length : number of timesteps
        input_len = model_output.shape[1]
        # Add one for additional <eos> token
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
        # Compute loss
        # (N, L, vocab_size + 1)
        cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
doctr/models/recognition/viptr/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -16,7 +16,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS, decode_sequence
 
 from ...classification import vip_tiny
-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["VIPTR", "viptr_tiny"]

@@ -70,7 +70,7 @@ class VIPTRPostProcessor(RecognitionPostProcessor):
 
    def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
        """Performs decoding of raw output with CTC and decoding of CTC predictions
-        with label_to_idx mapping
+        with label_to_idx mapping dictionary
 
        Args:
            logits: raw output of the model, shape (N, C + 1, seq_len)

@@ -166,7 +166,7 @@ class VIPTR(RecognitionModel, nn.Module):
 
        if target is None or return_preds:
            # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
                return self.postprocessor(decoded_features)
 
doctr/models/recognition/vitstr/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -15,7 +15,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS
 
 from ...classification import vit_b, vit_s
-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

@@ -117,7 +117,7 @@ class ViTSTR(_ViTSTR, nn.Module):
 
        if target is None or return_preds:
            # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
                return self.postprocessor(decoded_features)
 

@@ -149,7 +149,7 @@ class ViTSTR(_ViTSTR, nn.Module):
        # Input length : number of steps
        input_len = model_output.shape[1]
        # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
        # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
        # The "masked" first gt char is <sos>.
        cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
doctr/models/recognition/zoo.py
CHANGED

@@ -1,12 +1,12 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from typing import Any
 
-from doctr.file_utils import is_tf_available, is_torch_available
 from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import _CompiledModule
 
 from .. import recognition
 from .predictor import RecognitionPredictor

@@ -23,11 +23,9 @@ ARCHS: list[str] = [
    "vitstr_small",
    "vitstr_base",
    "parseq",
+    "viptr_tiny",
 ]
 
-if is_torch_available():
-    ARCHS.extend(["viptr_tiny"])
-
 
 def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
    if isinstance(arch, str):

@@ -38,14 +36,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
            pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
        )
    else:
-
-
-
-
-
-
-
-
+        # Adding the type for torch compiled models to the allowed architectures
+        allowed_archs = [
+            recognition.CRNN,
+            recognition.SAR,
+            recognition.MASTER,
+            recognition.ViTSTR,
+            recognition.PARSeq,
+            recognition.VIPTR,
+            _CompiledModule,
+        ]
 
        if not isinstance(arch, tuple(allowed_archs)):
            raise ValueError(f"unknown architecture: {type(arch)}")

@@ -56,7 +56,7 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
    kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
    kwargs["std"] = kwargs.get("std", _model.cfg["std"])
    kwargs["batch_size"] = kwargs.get("batch_size", 128)
-    input_shape = _model.cfg["input_shape"][
+    input_shape = _model.cfg["input_shape"][-2:]
    predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)
 
    return predictor
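With the TensorFlow guard gone, viptr_tiny is registered unconditionally, and _CompiledModule extends the instance check so a torch.compile-wrapped recognition model is accepted as a custom architecture. A usage sketch (assumes pretrained weights can be fetched; the compiled wrapper is what the new check is meant to let through):

import torch
from doctr.models import recognition_predictor
from doctr.models.recognition import viptr_tiny

model = viptr_tiny(pretrained=True)
compiled = torch.compile(model)   # OptimizedModule wrapper around the model
predictor = recognition_predictor(arch=compiled, pretrained=True)  # passes the allowed_archs check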
doctr/models/utils/__init__.py
CHANGED
doctr/models/utils/pytorch.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -164,12 +164,13 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
    """
    torch.onnx.export(
        model,
-        dummy_input,
+        dummy_input,  # type: ignore[arg-type]
        f"{model_name}.onnx",
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
        export_params=True,
+        dynamo=False,
        verbose=False,
        **kwargs,
    )
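The export hunk pins dynamo=False, explicitly selecting the legacy TorchScript-based ONNX export path now that recent PyTorch releases also ship the dynamo exporter. A sketch of calling the helper (the 32x128 input size is an assumption based on the usual recognition input; adjust to the model's cfg["input_shape"]):

import torch
from doctr.models import vitstr_small
from doctr.models.utils import export_model_to_onnx

model = vitstr_small(pretrained=True, exportable=True)
dummy_input = torch.rand((1, 3, 32, 128), dtype=torch.float32)  # (N, C, H, W)
onnx_path = export_model_to_onnx(model, model_name="vitstr_small", dummy_input=dummy_input)
# writes "vitstr_small.onnx" next to the script, per the f"{model_name}.onnx" pattern above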
doctr/models/zoo.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
doctr/transforms/functional/base.py
CHANGED

@@ -145,7 +145,8 @@ def create_shadow_mask(
 
    # Convert to absolute coords
    abs_contour: np.ndarray = (
-        np
+        np
+        .stack(
            (contour[:, 0] * target_shape[1], contour[:, 1] * target_shape[0]),
            axis=-1,
        )
doctr/transforms/functional/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -33,9 +33,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
    rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
    # Inverse the color
    if out.dtype == torch.uint8:
-        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
+        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
    else:
-        out = out * rgb_shift.to(dtype=out.dtype)
+        out = out * rgb_shift.to(dtype=out.dtype)
    # Inverse the color
    out = 255 - out if out.dtype == torch.uint8 else 1 - out
    return out

@@ -77,7 +77,7 @@ def rotate_sample(
    rotated_geoms: np.ndarray = rotate_abs_geoms(
        _geoms,
        angle,
-        img.shape[1:],
+        img.shape[1:],  # type: ignore[arg-type]
        expand,
    ).astype(np.float32)
 

@@ -124,7 +124,7 @@ def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwarg
    Returns:
        Shadowed image as a PyTorch tensor (same shape as input).
    """
-    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
+    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)  # type: ignore[arg-type]
    opacity = np.random.uniform(*opacity_range)
 
    # Apply Gaussian blur to the shadow mask
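The visible edits in this file add "# type: ignore[arg-type]" where img.shape[1:] (a torch.Size) is passed to helpers that are presumably annotated with plain tuples; the inversion math in invert_colors is unchanged. A numeric sketch of the two branches shown above, with an illustrative per-channel shift shape:

import torch

min_val = 0.6
rgb_shift = min_val + (1 - min_val) * torch.rand(3, 1, 1)  # illustrative shift shape

float_img = torch.rand(3, 32, 32)                          # values in [0, 1]
out_float = 1 - float_img * rgb_shift.to(float_img.dtype)  # float branch, then "1 - out"

uint8_img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
out_uint8 = 255 - (uint8_img.to(rgb_shift.dtype) * rgb_shift).to(torch.uint8)  # uint8 branch, then "255 - out"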
doctr/transforms/modules/base.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -20,27 +20,13 @@ __all__ = ["SampleCompose", "ImageTransform", "ColorInversion", "OneOf", "Random
 class SampleCompose(NestedObject):
    """Implements a wrapper that will apply transformations sequentially on both image and target
 
-    ..
+    .. code:: python
 
-
-
-
-
-
-            >>> import torch
-            >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-            >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-            >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
-
-    .. tab:: TensorFlow
-
-        .. code:: python
-
-            >>> import numpy as np
-            >>> import tensorflow as tf
-            >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
-            >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
-            >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4)))
+        >>> import numpy as np
+        >>> import torch
+        >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate
+        >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)])
+        >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4)))
 
    Args:
        transforms: list of transformation modules

@@ -61,25 +47,12 @@ class SampleCompose(NestedObject):
 class ImageTransform(NestedObject):
    """Implements a transform wrapper to turn an image-only transformation into an image+target transform
 
-    ..
-
-    .. tab:: PyTorch
-
-        .. code:: python
+    .. code:: python
 
-
-
-
-
-
-    .. tab:: TensorFlow
-
-        .. code:: python
-
-            >>> import tensorflow as tf
-            >>> from doctr.transforms import ImageTransform, ColorInversion
-            >>> transfo = ImageTransform(ColorInversion((32, 32)))
-            >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None)
+        >>> import torch
+        >>> from doctr.transforms import ImageTransform, ColorInversion
+        >>> transfo = ImageTransform(ColorInversion((32, 32)))
+        >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None)
 
    Args:
        transform: the image transformation module to wrap

@@ -99,25 +72,12 @@ class ColorInversion(NestedObject):
    """Applies the following tranformation to a tensor (image or batch of images):
    convert to grayscale, colorize (shift 0-values randomly), and then invert colors
 
-    ..
-
-    .. tab:: PyTorch
-
-        .. code:: python
-
-            >>> import torch
-            >>> from doctr.transforms import ColorInversion
-            >>> transfo = ColorInversion(min_val=0.6)
-            >>> out = transfo(torch.rand(8, 64, 64, 3))
+    .. code:: python
 
-
-
-
-
-            >>> import tensorflow as tf
-            >>> from doctr.transforms import ColorInversion
-            >>> transfo = ColorInversion(min_val=0.6)
-            >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
+        >>> import torch
+        >>> from doctr.transforms import ColorInversion
+        >>> transfo = ColorInversion(min_val=0.6)
+        >>> out = transfo(torch.rand(8, 64, 64, 3))
 
    Args:
        min_val: range [min_val, 1] to colorize RGB pixels

@@ -136,25 +96,12 @@ class ColorInversion(NestedObject):
 class OneOf(NestedObject):
    """Randomly apply one of the input transformations
 
-    ..
-
-    .. tab:: PyTorch
-
-        .. code:: python
-
-            >>> import torch
-            >>> from doctr.transforms import OneOf
-            >>> transfo = OneOf([JpegQuality(), Gamma()])
-            >>> out = transfo(torch.rand(1, 64, 64, 3))
-
-    .. tab:: TensorFlow
+    .. code:: python
 
-
-
-
-
-            >>> transfo = OneOf([JpegQuality(), Gamma()])
-            >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
+        >>> import torch
+        >>> from doctr.transforms import OneOf
+        >>> transfo = OneOf([JpegQuality(), Gamma()])
+        >>> out = transfo(torch.rand(1, 64, 64, 3))
 
    Args:
        transforms: list of transformations, one only will be picked

@@ -175,25 +122,12 @@ class OneOf(NestedObject):
 class RandomApply(NestedObject):
    """Apply with a probability p the input transformation
 
-    ..
-
-    .. tab:: PyTorch
-
-        .. code:: python
-
-            >>> import torch
-            >>> from doctr.transforms import RandomApply
-            >>> transfo = RandomApply(Gamma(), p=.5)
-            >>> out = transfo(torch.rand(1, 64, 64, 3))
-
-    .. tab:: TensorFlow
-
-        .. code:: python
+    .. code:: python
 
-
-
-
-
+        >>> import torch
+        >>> from doctr.transforms import RandomApply
+        >>> transfo = RandomApply(Gamma(), p=.5)
+        >>> out = transfo(torch.rand(1, 64, 64, 3))
 
    Args:
        transform: transformation to apply

@@ -258,7 +192,7 @@ class RandomCrop(NestedObject):
        scale = random.uniform(self.scale[0], self.scale[1])
        ratio = random.uniform(self.ratio[0], self.ratio[1])
 
-        height, width = img.shape[:
+        height, width = img.shape[-2:]
 
        # Calculate crop size
        crop_area = scale * width * height
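The RandomCrop change reads the spatial size as img.shape[-2:]. For channels-first tensors the leading dimensions are channel (and possibly batch), so indexing from the end is the robust way to get (H, W):

import torch

chw = torch.rand(3, 64, 128)        # (C, H, W)
nchw = torch.rand(8, 3, 64, 128)    # (N, C, H, W)
print(chw.shape[:2])                # torch.Size([3, 64])   -> channels and height, not (H, W)
print(chw.shape[-2:])               # torch.Size([64, 128]) -> (H, W)
print(nchw.shape[-2:])              # torch.Size([64, 128]) -> (H, W), batch-agnostic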