python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/predictor/_utils.py

@@ -1,9 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-
+
+import math

 import numpy as np

@@ -13,74 +14,132 @@ __all__ = ["split_crops", "remap_preds"]


 def split_crops(
-    crops:
+    crops: list[np.ndarray],
     max_ratio: float,
     target_ratio: int,
-
+    split_overlap_ratio: float,
     channels_last: bool = True,
-) ->
-    """
+) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
+    """
+    Split crops horizontally if they exceed a given aspect ratio.

     Args:
-
-
-
-
-
-        channels_last: whether the numpy array has dimensions in channels last order
+        crops: List of image crops (H, W, C) if channels_last else (C, H, W).
+        max_ratio: Aspect ratio threshold above which crops are split.
+        target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
+        split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
+        channels_last: Whether the crops are in channels-last format.

     Returns:
-
-
+        A tuple containing:
+        - The new list of crops (possibly with splits),
+        - A mapping indicating how to reassemble predictions,
+        - A boolean indicating whether remapping is required.
     """
-
-
-
+    if split_overlap_ratio <= 0.0 or split_overlap_ratio >= 1.0:
+        raise ValueError(f"Valid range for split_overlap_ratio is (0.0, 1.0), but is: {split_overlap_ratio}")
+
+    remap_required = False
+    new_crops: list[np.ndarray] = []
+    crop_map: list[int | tuple[int, int, float]] = []
+
     for crop in crops:
         h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
         aspect_ratio = w / h
+
         if aspect_ratio > max_ratio:
-
-
-
-
-
-            #
-            if
-
-
-
-
+            split_width = max(1, math.ceil(h * target_ratio))
+            overlap_width = max(0, math.floor(split_width * split_overlap_ratio))
+
+            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width, channels_last)
+
+            # Remove any empty splits
+            splits = [s for s in splits if all(dim > 0 for dim in s.shape)]
+            if splits:
+                crop_map.append((len(new_crops), len(new_crops) + len(splits), last_overlap))
+                new_crops.extend(splits)
+                remap_required = True
             else:
-
-
-
-            ]
-            # Avoid sending zero-sized crops
-            _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
-            # Record the slice of crops
-            crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
-            new_crops.extend(_crops)
-            # At least one crop will require merging
-            _remap_required = True
+                # Fallback: treat it as a single crop
+                crop_map.append(len(new_crops))
+                new_crops.append(crop)
         else:
             crop_map.append(len(new_crops))
             new_crops.append(crop)

-    return new_crops, crop_map,
+    return new_crops, crop_map, remap_required
+
+
+def _split_horizontally(
+    image: np.ndarray, split_width: int, overlap_width: int, channels_last: bool
+) -> tuple[list[np.ndarray], float]:
+    """
+    Horizontally split a single image with overlapping regions.
+
+    Args:
+        image: The image to split (H, W, C) if channels_last else (C, H, W).
+        split_width: Width of each split.
+        overlap_width: Width of the overlapping region.
+        channels_last: Whether the image is in channels-last format.
+
+    Returns:
+        - A list of horizontal image slices.
+        - The actual overlap ratio of the last split.
+    """
+    image_width = image.shape[1] if channels_last else image.shape[-1]
+    if image_width <= split_width:
+        return [image], 0.0
+
+    # Compute start columns for each split
+    step = split_width - overlap_width
+    starts = list(range(0, image_width - split_width + 1, step))
+
+    # Ensure the last patch reaches the end of the image
+    if starts[-1] + split_width < image_width:
+        starts.append(image_width - split_width)
+
+    splits = []
+    for start_col in starts:
+        end_col = start_col + split_width
+        if channels_last:
+            split = image[:, start_col:end_col, :]
+        else:
+            split = image[:, :, start_col:end_col]
+        splits.append(split)
+
+    # Calculate the last overlap ratio, if only one split no overlap
+    last_overlap = 0
+    if len(starts) > 1:
+        last_overlap = (starts[-2] + split_width) - starts[-1]
+    last_overlap_ratio = last_overlap / split_width if split_width else 0.0
+
+    return splits, last_overlap_ratio


 def remap_preds(
-    preds:
-
-
-
-
-
+    preds: list[tuple[str, float]],
+    crop_map: list[int | tuple[int, int, float]],
+    overlap_ratio: float,
+) -> list[tuple[str, float]]:
+    """
+    Reconstruct predictions from possibly split crops.
+
+    Args:
+        preds: List of (text, confidence) tuples from each crop.
+        crop_map: Map returned by `split_crops`.
+        overlap_ratio: Overlap ratio used during splitting.
+
+    Returns:
+        List of merged (text, confidence) tuples corresponding to original crops.
+    """
+    remapped = []
+    for item in crop_map:
+        if isinstance(item, int):
+            remapped.append(preds[item])
         else:
-
-
-
-
-
+            start_idx, end_idx, last_overlap = item
+            text_parts, confidences = zip(*preds[start_idx:end_idx])
+            merged_text = merge_multi_strings(list(text_parts), overlap_ratio, last_overlap)
+            merged_conf = sum(confidences) / len(confidences)  # average confidence
+            remapped.append((merged_text, merged_conf))
+    return remapped
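The new `split_crops` cuts any crop whose aspect ratio exceeds `max_ratio` into fixed-width windows of `h * target_ratio` pixels that overlap by `split_overlap_ratio`, snaps the last window to the right edge, and records the actual last overlap so that predictions can be merged back. The standalone sketch below (not part of the package) reproduces that start-column arithmetic using the predictor defaults shown in the next two diffs (critical aspect ratio 8, target ratio 6, overlap 0.5); the 32x800 crop is an arbitrary example.

import math

import numpy as np

crop = np.zeros((32, 800, 3), dtype=np.uint8)  # a very wide word crop (H, W, C)
h, w = crop.shape[:2]

target_ratio, overlap_ratio = 6, 0.5
split_width = max(1, math.ceil(h * target_ratio))  # 192 px per window
overlap_width = max(0, math.floor(split_width * overlap_ratio))  # 96 px shared between neighbours

step = split_width - overlap_width
starts = list(range(0, w - split_width + 1, step))
if starts[-1] + split_width < w:
    starts.append(w - split_width)  # last window hugs the right edge

splits = [crop[:, s : s + split_width, :] for s in starts]
last_overlap_ratio = ((starts[-2] + split_width) - starts[-1]) / split_width if len(starts) > 1 else 0.0
print(len(splits), round(last_overlap_ratio, 2))  # 8 windows, last overlap ratio 0.83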
doctr/models/recognition/predictor/pytorch.py

@@ -1,9 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-from
+from collections.abc import Sequence
+from typing import Any

 import numpy as np
 import torch
@@ -21,7 +22,6 @@ class RecognitionPredictor(nn.Module):
     """Implements an object able to identify character sequences in images

     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core detection architecture
         split_wide_crops: wether to use crop splitting for high aspect ratio crops
@@ -38,15 +38,15 @@ class RecognitionPredictor(nn.Module):
         self.model = model.eval()
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio

     @torch.inference_mode()
     def forward(
         self,
-        crops: Sequence[
+        crops: Sequence[np.ndarray | torch.Tensor],
         **kwargs: Any,
-    ) ->
+    ) -> list[tuple[str, float]]:
         if len(crops) == 0:
             return []
         # Dimension check
@@ -60,14 +60,14 @@ class RecognitionPredictor(nn.Module):
                 crops,  # type: ignore[arg-type]
                 self.critical_ar,
                 self.target_ar,
-                self.
+                self.overlap_ratio,
                 isinstance(crops[0], np.ndarray),
             )
             if remapped:
                 crops = new_crops

         # Resize & batch them
-        processed_batches = self.pre_processor(crops)
+        processed_batches = self.pre_processor(crops)  # type: ignore[arg-type]

         # Forward it
         _params = next(self.model.parameters())
@@ -81,6 +81,6 @@ class RecognitionPredictor(nn.Module):

         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.
+            out = remap_preds(out, crop_map, self.overlap_ratio)

         return out
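The predictor now stores the overlap as a plain attribute (`self.overlap_ratio = 0.5`) and passes it straight through to `split_crops` and `remap_preds`; the TensorFlow backend below receives the same change. A hedged usage sketch, assuming the public `recognition_predictor` zoo helper and a downloadable `crnn_vgg16_bn` checkpoint; tuning the attribute after construction is shown only to illustrate where the value lives, not as an officially documented knob.

from doctr.models import recognition_predictor

reco = recognition_predictor("crnn_vgg16_bn", pretrained=True)

# Defaults set in __init__ per the diff above
print(reco.critical_ar, reco.target_ar, reco.overlap_ratio)  # 8 6 0.5

# Smaller overlap -> larger step between neighbouring windows on wide crops
reco.overlap_ratio = 0.25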
doctr/models/recognition/predictor/tensorflow.py

@@ -1,9 +1,9 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-from typing import Any
+from typing import Any

 import numpy as np
 import tensorflow as tf
@@ -21,13 +21,12 @@ class RecognitionPredictor(NestedObject):
     """Implements an object able to identify character sequences in images

     Args:
-    ----
         pre_processor: transform inputs for easier batched model inference
         model: core detection architecture
         split_wide_crops: wether to use crop splitting for high aspect ratio crops
     """

-    _children_names:
+    _children_names: list[str] = ["pre_processor", "model"]

     def __init__(
         self,
@@ -40,14 +39,14 @@ class RecognitionPredictor(NestedObject):
         self.model = model
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio

     def __call__(
         self,
-        crops:
+        crops: list[np.ndarray | tf.Tensor],
         **kwargs: Any,
-    ) ->
+    ) -> list[tuple[str, float]]:
         if len(crops) == 0:
             return []
         # Dimension check
@@ -57,7 +56,7 @@ class RecognitionPredictor(NestedObject):
         # Split crops that are too wide
         remapped = False
         if self.split_wide_crops:
-            new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.
+            new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.overlap_ratio)
             if remapped:
                 crops = new_crops

@@ -75,6 +74,6 @@ class RecognitionPredictor(NestedObject):

         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.
+            out = remap_preds(out, crop_map, self.overlap_ratio)

         return out
doctr/models/recognition/sar/__init__.py

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available

-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/recognition/sar/pytorch.py

@@ -1,10 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any
+from typing import Any

 import torch
 from torch import nn
@@ -19,7 +20,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor

 __all__ = ["SAR", "sar_resnet31"]

-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "sar_resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -80,7 +81,6 @@ class SARDecoder(nn.Module):
     """Implements decoder module of the SAR model

     Args:
-    ----
         rnn_units: number of hidden units in recurrent cells
         max_length: maximum length of a sequence
         vocab_size: number of classes in the model alphabet
@@ -114,12 +114,12 @@ class SARDecoder(nn.Module):
         self,
         features: torch.Tensor,  # (N, C, H, W)
         holistic: torch.Tensor,  # (N, C)
-        gt:
+        gt: torch.Tensor | None = None,  # (N, L)
     ) -> torch.Tensor:
         if gt is not None:
             gt_embedding = self.embed_tgt(gt)

-        logits_list:
+        logits_list: list[torch.Tensor] = []

         for t in range(self.max_length + 1):  # 32
             if t == 0:
@@ -166,7 +166,6 @@ class SAR(nn.Module, RecognitionModel):
     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -187,9 +186,9 @@ class SAR(nn.Module, RecognitionModel):
         attention_units: int = 512,
         max_length: int = 30,
         dropout_prob: float = 0.0,
-        input_shape:
+        input_shape: tuple[int, int, int] = (3, 32, 128),
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -229,13 +228,22 @@ class SAR(nn.Module, RecognitionModel):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)

+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]
         # NOTE: use max instead of functional max_pool2d which leads to ONNX incompatibility (kernel_size)
         # Vertical max pooling (N, C, H, W) --> (N, C, W)
@@ -254,7 +262,7 @@ class SAR(nn.Module, RecognitionModel):

         decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))

-        out:
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -263,8 +271,13 @@ class SAR(nn.Module, RecognitionModel):
         out["out_map"] = decoded_features

         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
             # Post-process boxes
-            out["preds"] =
+            out["preds"] = _postprocess(decoded_features)

         if target is not None:
             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
@@ -281,19 +294,17 @@ class SAR(nn.Module, RecognitionModel):
         Sequences are masked after the EOS character.

         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch

         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1  # type: ignore[assignment]
         # Compute loss
         # (N, L, vocab_size + 1)
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
@@ -308,14 +319,13 @@ class SARPostProcessor(RecognitionPostProcessor):
     """Post processor for SAR architectures

     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """

     def __call__(
         self,
         logits: torch.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
         # N x L
@@ -338,7 +348,7 @@ def _sar(
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
     pretrained_backbone: bool = True,
-    ignore_keys:
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> SAR:
     pretrained_backbone = pretrained_backbone and not pretrained
@@ -363,7 +373,7 @@ def _sar(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     return model

@@ -379,12 +389,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
     >>> out = model(input_tensor)

     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the SAR architecture

     Returns:
-    -------
         text recognition architecture
     """
     return _sar(