python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +1 -5
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1100 -54
- doctr/file_utils.py +2 -92
- doctr/io/elements.py +37 -3
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +4 -4
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -4
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +11 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vip/__init__.py +1 -0
- doctr/models/classification/vip/layers/__init__.py +1 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +12 -3
- doctr/models/classification/zoo.py +7 -8
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +7 -16
- doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +6 -17
- doctr/models/detection/fast/pytorch.py +17 -8
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +5 -15
- doctr/models/detection/linknet/pytorch.py +12 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +9 -22
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +52 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +16 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +15 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +26 -8
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +100 -47
- doctr/models/recognition/predictor/pytorch.py +4 -5
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +13 -4
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +1 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +13 -4
- doctr/models/recognition/zoo.py +13 -8
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +29 -19
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -433
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -397
- doctr/models/classification/textnet/tensorflow.py +0 -266
- doctr/models/classification/vgg/tensorflow.py +0 -116
- doctr/models/classification/vit/tensorflow.py +0 -192
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
- doctr/models/detection/fast/tensorflow.py +0 -419
- doctr/models/detection/linknet/tensorflow.py +0 -369
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -308
- doctr/models/recognition/master/tensorflow.py +0 -313
- doctr/models/recognition/parseq/tensorflow.py +0 -508
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -416
- doctr/models/recognition/vitstr/tensorflow.py +0 -278
- doctr/models/utils/tensorflow.py +0 -182
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.11.0.dist-info/RECORD +0 -173
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
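
Every backend-specific `*/tensorflow.py` module listed above (datasets, classification, detection, recognition, transforms, utils) is removed in 1.0.0, leaving PyTorch as the only backend shipped in the wheel. For orientation, a minimal end-to-end sketch using the high-level entry points, which this diff does not change (the sample file path is illustrative):

```python
import numpy as np

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# Load a document as a list of HWC numpy arrays (one per page)
pages: list[np.ndarray] = DocumentFile.from_images("sample.jpg")  # illustrative path

# Build the default detection + recognition pipeline (PyTorch-only in 1.0.0)
predictor = ocr_predictor(pretrained=True)

# Run OCR and export the structured result
result = predictor(pages)
print(result.export())
```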
doctr/models/predictor/base.py
CHANGED
@@ -116,18 +116,14 @@ class _OCRPredictor:
     def _generate_crops(
         pages: list[np.ndarray],
         loc_preds: list[np.ndarray],
-        channels_last: bool,
         assume_straight_pages: bool = False,
         assume_horizontal: bool = False,
     ) -> list[list[np.ndarray]]:
         if assume_straight_pages:
-            crops = [
-                extract_crops(page, _boxes[:, :4], channels_last=channels_last)
-                for page, _boxes in zip(pages, loc_preds)
-            ]
+            crops = [extract_crops(page, _boxes[:, :4]) for page, _boxes in zip(pages, loc_preds)]
         else:
             crops = [
-                extract_rcrops(page, _boxes[:, :4],
+                extract_rcrops(page, _boxes[:, :4], assume_horizontal=assume_horizontal)
                 for page, _boxes in zip(pages, loc_preds)
             ]
         return crops
@@ -136,11 +132,10 @@ class _OCRPredictor:
     def _prepare_crops(
         pages: list[np.ndarray],
         loc_preds: list[np.ndarray],
-        channels_last: bool,
         assume_straight_pages: bool = False,
         assume_horizontal: bool = False,
     ) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
-        crops = _OCRPredictor._generate_crops(pages, loc_preds,
+        crops = _OCRPredictor._generate_crops(pages, loc_preds, assume_straight_pages, assume_horizontal)

         # Avoid sending zero-sized crops
         is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]

doctr/models/predictor/pytorch.py
CHANGED
@@ -68,14 +68,14 @@ class OCRPredictor(nn.Module, _OCRPredictor):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray
+        pages: list[np.ndarray],
         **kwargs: Any,
     ) -> Document:
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")

-        origin_page_shapes = [page.shape[:2]
+        origin_page_shapes = [page.shape[:2] for page in pages]

         # Localize text elements
         loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
@@ -109,8 +109,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
         # Detach objectness scores from loc_preds
         loc_preds, objectness_scores = detach_scores(loc_preds)
-        # Check whether crop mode should be switched to channels first
-        channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)

         # Apply hooks to loc_preds if any
         for hook in self.hooks:
@@ -120,7 +118,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         crops, loc_preds = self._prepare_crops(
             pages,
             loc_preds,
-            channels_last=channels_last,
             assume_straight_pages=self.assume_straight_pages,
             assume_horizontal=self._page_orientation_disabled,
         )
@@ -150,7 +147,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
             boxes,
             objectness_scores,
             text_preds,
-            origin_page_shapes,
+            origin_page_shapes,
             crop_orientations,
             orientations,
             languages_dict,
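
`OCRPredictor.forward` is now typed as `pages: list[np.ndarray]` and the `channels_last` bookkeeping is removed, so pages are expected as HWC numpy arrays. A hedged helper sketch (a hypothetical function, not part of docTR) for adapting a CHW torch tensor before calling the predictor:

```python
import numpy as np
import torch


def to_hwc_uint8(page: torch.Tensor | np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert a CHW tensor (float in [0, 1]) to the HWC uint8 layout."""
    if isinstance(page, torch.Tensor):
        page = page.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
    if page.dtype != np.uint8:
        page = (np.clip(page, 0.0, 1.0) * 255).astype(np.uint8)
    return page
```

The removed `channels_last` check existed to support channels-first tensor pages; converting up front keeps such callers on the supported numpy path.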

doctr/models/preprocessor/pytorch.py
CHANGED
@@ -60,65 +60,60 @@ class PreProcessor(nn.Module):

         return batches

-    def sample_transforms(self, x: np.ndarray
+    def sample_transforms(self, x: np.ndarray) -> torch.Tensor:
         if x.ndim != 3:
             raise AssertionError("expected list of 3D Tensors")
-        if
-
-
-            x = torch.from_numpy(x.copy()).permute(2, 0, 1)
-        elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
-            raise TypeError("unsupported data type for torch.Tensor")
+        if x.dtype not in (np.uint8, np.float32, np.float16):
+            raise TypeError("unsupported data type for numpy.ndarray")
+        tensor = torch.from_numpy(x.copy()).permute(2, 0, 1)
         # Resizing
-
+        tensor = self.resize(tensor)
         # Data type
-        if
-
+        if tensor.dtype == torch.uint8:
+            tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
         else:
-
+            tensor = tensor.to(dtype=torch.float32)

-        return
+        return tensor

-    def __call__(self, x:
+    def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]:
         """Prepare document data for model forwarding

         Args:
-            x: list of images (np.array) or
+            x: list of images (np.array) or a single image (np.array) of shape (H, W, C)

         Returns:
-            list of page batches
+            list of page batches (*, C, H, W) ready for model inference
         """
         # Input type check
-        if isinstance(x,
+        if isinstance(x, np.ndarray):
             if x.ndim != 4:
                 raise AssertionError("expected 4D Tensor")
-            if
-
-
-
-            elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
-                raise TypeError("unsupported data type for torch.Tensor")
+            if x.dtype not in (np.uint8, np.float32, np.float16):
+                raise TypeError("unsupported data type for numpy.ndarray")
+            tensor = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
+
             # Resizing
-            if
-
-
+            if tensor.shape[-2] != self.resize.size[0] or tensor.shape[-1] != self.resize.size[1]:
+                tensor = F.resize(
+                    tensor, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
                 )
             # Data type
-            if
-
+            if tensor.dtype == torch.uint8:
+                tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
             else:
-
-            batches = [
+                tensor = tensor.to(dtype=torch.float32)
+            batches = [tensor]

-        elif isinstance(x, list) and all(isinstance(sample,
+        elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
             # Sample transform (to tensor, resize)
             samples = list(multithread_exec(self.sample_transforms, x))
             # Batching
-            batches = self.batch_inputs(samples)
+            batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")

         # Batch transforms (normalize)
         batches = list(multithread_exec(self.normalize, batches))

-        return batches
+        return batches
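
The rewritten `sample_transforms` accepts only numpy input and always converts HWC numpy to a CHW float32 tensor in [0, 1]. A standalone sketch of that conversion, mirroring the added lines above (without the resize step, which depends on the preprocessor's configured `self.resize`):

```python
import numpy as np
import torch


def numpy_page_to_tensor(x: np.ndarray) -> torch.Tensor:
    """HWC numpy image -> CHW float32 tensor in [0, 1], as in the new sample_transforms."""
    if x.ndim != 3:
        raise AssertionError("expected a 3D array (H, W, C)")
    if x.dtype not in (np.uint8, np.float32, np.float16):
        raise TypeError("unsupported data type for numpy.ndarray")
    tensor = torch.from_numpy(x.copy()).permute(2, 0, 1)
    if tensor.dtype == torch.uint8:
        return tensor.to(dtype=torch.float32).div(255).clip(0, 1)
    return tensor.to(dtype=torch.float32)


# Example: a dummy uint8 page
page = np.random.randint(0, 256, (32, 128, 3), dtype=np.uint8)
tensor = numpy_page_to_tensor(page)
print(tensor.shape, tensor.dtype)  # torch.Size([3, 32, 128]) torch.float32
```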

doctr/models/recognition/crnn/pytorch.py
CHANGED
@@ -15,7 +15,7 @@ from torch.nn import functional as F
 from doctr.datasets import VOCABS, decode_sequence

 from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
-from ...utils
+from ...utils import load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor

 __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -25,8 +25,8 @@ default_cfgs: dict[str, dict[str, Any]] = {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 128),
-        "vocab": VOCABS["
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "vocab": VOCABS["french"],
+        "url": "https://doctr-static.mindee.com/models?id=v0.12.0/crnn_vgg16_bn-0417f351.pt&src=0",
     },
     "crnn_mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),
@@ -82,7 +82,7 @@ class CTCPostProcessor(RecognitionPostProcessor):

     def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
         """Performs decoding of raw output with CTC and decoding of CTC predictions
-        with label_to_idx mapping
+        with label_to_idx mapping dictionary

         Args:
             logits: raw output of the model, shape (N, C + 1, seq_len)
@@ -155,6 +155,15 @@ class CRNN(RecognitionModel, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()

+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         model_output: torch.Tensor,
@@ -214,7 +223,7 @@ class CRNN(RecognitionModel, nn.Module):

         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)

@@ -248,13 +257,13 @@ def _crnn(
         _cfg["input_shape"] = kwargs["input_shape"]

     # Build the model
-    model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
+    model = CRNN(feat_extractor, cfg=_cfg, **kwargs)  # type: ignore[arg-type]
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-
+        model.from_pretrained(_cfg["url"], ignore_keys=_ignore_keys)

     return model

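
The CRNN diff (and the MASTER and PARSeq diffs below) adds a `from_pretrained(path_or_url, **kwargs)` method that wraps `load_pretrained_params`, and the `_crnn` builder now calls `model.from_pretrained(_cfg["url"], ...)` instead of the free function. A hedged usage sketch (the checkpoint path is illustrative):

```python
from doctr.models.recognition import crnn_vgg16_bn

# Build the architecture without weights, then load a checkpoint from a local path or URL
model = crnn_vgg16_bn(pretrained=False)
model.from_pretrained("path/to/crnn_vgg16_bn_checkpoint.pt")  # illustrative path

# The usual one-liner still works and now routes through the same method internally
model = crnn_vgg16_bn(pretrained=True)
```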

doctr/models/recognition/master/pytorch.py
CHANGED
@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
 from doctr.models.classification import magc_resnet31
 from doctr.models.modules.transformer import Decoder, PositionalEncoding

-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _MASTER, _MASTERPostProcessor

 __all__ = ["MASTER", "master"]
@@ -107,7 +107,7 @@ class MASTER(_MASTER, nn.Module):
         # NOTE: nn.TransformerDecoder takes the inverse from this implementation
         # [True, True, True, ..., False, False, False] -> False is masked
         # (N, 1, 1, max_length)
-        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
+        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
         target_length = target.size(1)
         # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
         # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
@@ -140,7 +140,7 @@ class MASTER(_MASTER, nn.Module):
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>. Delete last logit of the model output.
         cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -151,6 +151,15 @@ class MASTER(_MASTER, nn.Module):
         ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
         return ce_loss.mean()

+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -167,7 +176,7 @@ class MASTER(_MASTER, nn.Module):
             return_preds: if True, decode logits

         Returns:
-            A
+            A dictionary containing eventually loss, logits and predictions.
         """
         # Encode
         features = self.feat_extractor(x)["features"]
@@ -210,7 +219,7 @@ class MASTER(_MASTER, nn.Module):

         if return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)

@@ -301,7 +310,7 @@ def _master(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     return model


doctr/models/recognition/parseq/pytorch.py
CHANGED
@@ -19,7 +19,7 @@ from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

 from ...classification import vit_s
-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor

 __all__ = ["PARSeq", "parseq"]
@@ -76,8 +76,6 @@ class PARSeqDecoder(nn.Module):
         self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
         self.position_feed_forward = PositionwiseFeedForward(d_model, ffd * ffd_ratio, dropout, nn.GELU())

-        self.attention_norm = nn.LayerNorm(d_model, eps=1e-5)
-        self.cross_attention_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.query_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.content_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.feed_forward_norm = nn.LayerNorm(d_model, eps=1e-5)
@@ -173,6 +171,26 @@ class PARSeq(_PARSeq, nn.Module):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)

+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
+        # ref.: https://github.com/mindee/doctr/issues/1911
+        if kwargs.get("ignore_keys") is None:
+            kwargs["ignore_keys"] = []
+
+        kwargs["ignore_keys"].extend([
+            "decoder.attention_norm.weight",
+            "decoder.attention_norm.bias",
+            "decoder.cross_attention_norm.weight",
+            "decoder.cross_attention_norm.bias",
+        ])
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def generate_permutations(self, seqlen: torch.Tensor) -> torch.Tensor:
         # Generates permutations of the target sequence.
         # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -210,7 +228,7 @@ class PARSeq(_PARSeq, nn.Module):

         sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
         eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
-        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
+        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
         if len(combined) > 1:
             combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
         return combined
@@ -281,7 +299,7 @@ class PARSeq(_PARSeq, nn.Module):

             # Stop decoding if all sequences have reached the EOS token
             # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
+            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
                 break

         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)
@@ -296,7 +314,7 @@ class PARSeq(_PARSeq, nn.Module):

             # Create padding mask for refined target input maskes all behind EOS token as False
             # (N, 1, 1, max_length)
-            target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+            target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
             mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
             logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))

@@ -373,7 +391,7 @@ class PARSeq(_PARSeq, nn.Module):

         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)

@@ -448,7 +466,7 @@ def _parseq(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     return model


doctr/models/recognition/predictor/_utils.py
CHANGED
@@ -4,6 +4,8 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


+import math
+
 import numpy as np

 from ..utils import merge_multi_strings
@@ -15,69 +17,120 @@ def split_crops(
     crops: list[np.ndarray],
     max_ratio: float,
     target_ratio: int,
-
-
-
-
+    split_overlap_ratio: float,
+) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
+    """
+    Split crops horizontally if they exceed a given aspect ratio.

     Args:
-        crops:
-        max_ratio:
-        target_ratio:
-
-        channels_last: whether the numpy array has dimensions in channels last order
+        crops: List of image crops (H, W, C).
+        max_ratio: Aspect ratio threshold above which crops are split.
+        target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
+        split_overlap_ratio: Desired overlap between splits (as a fraction of split width).

     Returns:
-
+        A tuple containing:
+        - The new list of crops (possibly with splits),
+        - A mapping indicating how to reassemble predictions,
+        - A boolean indicating whether remapping is required.
     """
-
-
+    if split_overlap_ratio <= 0.0 or split_overlap_ratio >= 1.0:
+        raise ValueError(f"Valid range for split_overlap_ratio is (0.0, 1.0), but is: {split_overlap_ratio}")
+
+    remap_required = False
     new_crops: list[np.ndarray] = []
+    crop_map: list[int | tuple[int, int, float]] = []
+
     for crop in crops:
-        h, w = crop.shape[:2]
+        h, w = crop.shape[:2]
         aspect_ratio = w / h
+
         if aspect_ratio > max_ratio:
-
-
-
-
-
-            #
-            if
-
-
-
-
+            split_width = max(1, math.ceil(h * target_ratio))
+            overlap_width = max(0, math.floor(split_width * split_overlap_ratio))
+
+            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width)
+
+            # Remove any empty splits
+            splits = [s for s in splits if all(dim > 0 for dim in s.shape)]
+            if splits:
+                crop_map.append((len(new_crops), len(new_crops) + len(splits), last_overlap))
+                new_crops.extend(splits)
+                remap_required = True
             else:
-
-
-
-                ]
-            # Avoid sending zero-sized crops
-            _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
-            # Record the slice of crops
-            crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
-            new_crops.extend(_crops)
-            # At least one crop will require merging
-            _remap_required = True
+                # Fallback: treat it as a single crop
+                crop_map.append(len(new_crops))
+                new_crops.append(crop)
         else:
             crop_map.append(len(new_crops))
             new_crops.append(crop)

-    return new_crops, crop_map,
+    return new_crops, crop_map, remap_required
+
+
+def _split_horizontally(image: np.ndarray, split_width: int, overlap_width: int) -> tuple[list[np.ndarray], float]:
+    """
+    Horizontally split a single image with overlapping regions.
+
+    Args:
+        image: The image to split (H, W, C).
+        split_width: Width of each split.
+        overlap_width: Width of the overlapping region.
+
+    Returns:
+        - A list of horizontal image slices.
+        - The actual overlap ratio of the last split.
+    """
+    image_width = image.shape[1]
+    if image_width <= split_width:
+        return [image], 0.0
+
+    # Compute start columns for each split
+    step = split_width - overlap_width
+    starts = list(range(0, image_width - split_width + 1, step))
+
+    # Ensure the last patch reaches the end of the image
+    if starts[-1] + split_width < image_width:
+        starts.append(image_width - split_width)
+
+    splits = []
+    for start_col in starts:
+        end_col = start_col + split_width
+        splits.append(image[:, start_col:end_col, :])
+
+    # Calculate the last overlap ratio, if only one split no overlap
+    last_overlap = 0
+    if len(starts) > 1:
+        last_overlap = (starts[-2] + split_width) - starts[-1]
+    last_overlap_ratio = last_overlap / split_width if split_width else 0.0
+
+    return splits, last_overlap_ratio


 def remap_preds(
-    preds: list[tuple[str, float]],
+    preds: list[tuple[str, float]],
+    crop_map: list[int | tuple[int, int, float]],
+    overlap_ratio: float,
 ) -> list[tuple[str, float]]:
-
-
-
-
-
+    """
+    Reconstruct predictions from possibly split crops.
+
+    Args:
+        preds: List of (text, confidence) tuples from each crop.
+        crop_map: Map returned by `split_crops`.
+        overlap_ratio: Overlap ratio used during splitting.
+
+    Returns:
+        List of merged (text, confidence) tuples corresponding to original crops.
+    """
+    remapped = []
+    for item in crop_map:
+        if isinstance(item, int):
+            remapped.append(preds[item])
         else:
-
-
-
-
-
+            start_idx, end_idx, last_overlap = item
+            text_parts, confidences = zip(*preds[start_idx:end_idx])
+            merged_text = merge_multi_strings(list(text_parts), overlap_ratio, last_overlap)
+            merged_conf = sum(confidences) / len(confidences)  # average confidence
+            remapped.append((merged_text, merged_conf))
+    return remapped
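
The reworked `split_crops` replaces the old dilation-based sub-crops with fixed-width, overlapping splits computed by `_split_horizontally`, and `remap_preds` merges the per-split predictions back with `merge_multi_strings`. A small standalone sketch of the split geometry, using the same arithmetic as the added `_split_horizontally` on a made-up crop size:

```python
import math

# A hypothetical wide crop: 32 px high, 400 px wide, target aspect ratio 4 -> 128 px splits
height, width, target_ratio, split_overlap_ratio = 32, 400, 4, 0.5

split_width = max(1, math.ceil(height * target_ratio))                  # 128
overlap_width = max(0, math.floor(split_width * split_overlap_ratio))   # 64
step = split_width - overlap_width                                      # 64

starts = list(range(0, width - split_width + 1, step))                  # [0, 64, 128, 192, 256]
if starts[-1] + split_width < width:
    starts.append(width - split_width)                                  # append 272 to cover the full width

# Each (start, start + split_width) column range becomes one crop fed to the recognizer;
# the last split overlaps its predecessor by (starts[-2] + split_width) - starts[-1] = 112 px,
# and remap_preds uses that ratio (112 / 128) when merging the overlapping text predictions.
print([(s, s + split_width) for s in starts])
```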