python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/vitstr/pytorch.py
CHANGED

@@ -1,10 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import torch
 from torch import nn
@@ -19,7 +20,7 @@ from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "vitstr_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -42,7 +43,6 @@ class ViTSTR(_ViTSTR, nn.Module):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -59,9 +59,9 @@ class ViTSTR(_ViTSTR, nn.Module):
         vocab: str,
         embedding_units: int,
         max_length: int = 32,  # different from paper
-        input_shape:
+        input_shape: tuple[int, int, int] = (3, 32, 128),  # different from paper
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -74,13 +74,22 @@ class ViTSTR(_ViTSTR, nn.Module):
 
         self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]  # (batch_size, patches_seqlen, d_model)
 
         if target is not None:
@@ -98,7 +107,7 @@ class ViTSTR(_ViTSTR, nn.Module):
         logits = self.head(features).view(B, N, len(self.vocab) + 1)  # (batch_size, max_length, vocab + 1)
         decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
-        out:
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -107,8 +116,13 @@ class ViTSTR(_ViTSTR, nn.Module):
             out["out_map"] = decoded_features
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
             # Post-process boxes
-            out["preds"] =
+            out["preds"] = _postprocess(decoded_features)
 
         if target is not None:
             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
@@ -125,19 +139,17 @@ class ViTSTR(_ViTSTR, nn.Module):
         Sequences are masked after the EOS character.
 
         Args:
-        ----
            model_output: predicted logits of the model
            gt: the encoded tensor with gt labels
            seq_len: lengths of each gt word inside the batch
 
        Returns:
-        -------
            The loss of the model on the batch
        """
        # Input length : number of steps
        input_len = model_output.shape[1]
        # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1  # type: ignore[assignment]
        # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
        # The "masked" first gt char is <sos>.
        cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -153,14 +165,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: torch.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
         preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
@@ -183,7 +194,7 @@ def _vitstr(
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
-    ignore_keys:
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> ViTSTR:
     # Patch the config
@@ -212,7 +223,7 @@ def _vitstr(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -228,12 +239,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
@@ -259,12 +268,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
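
The hunks above add a from_pretrained method to the PyTorch ViTSTR model. A minimal usage sketch, assuming a locally fine-tuned checkpoint (the file name is illustrative, not a release asset):

    from doctr.models import vitstr_small

    model = vitstr_small(pretrained=False)
    # path_or_url accepts either a local checkpoint path or a URL, per the docstring above
    model.from_pretrained("./vitstr_small-finetuned.pt")
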
doctr/models/recognition/vitstr/tensorflow.py
CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, layers
@@ -17,7 +17,7 @@ from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "vitstr_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -40,7 +40,6 @@ class ViTSTR(_ViTSTR, Model):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -51,7 +50,7 @@ class ViTSTR(_ViTSTR, Model):
         cfg: dictionary containing information about the model
     """
 
-    _children_names:
+    _children_names: list[str] = ["feat_extractor", "postprocessor"]
 
     def __init__(
         self,
@@ -60,9 +59,9 @@ class ViTSTR(_ViTSTR, Model):
         embedding_units: int,
         max_length: int = 32,
         dropout_prob: float = 0.0,
-        input_shape:
+        input_shape: tuple[int, int, int] = (32, 128, 3),  # different from paper
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -75,23 +74,30 @@ class ViTSTR(_ViTSTR, Model):
 
         self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     @staticmethod
     def compute_loss(
         model_output: tf.Tensor,
         gt: tf.Tensor,
-        seq_len:
+        seq_len: list[int],
     ) -> tf.Tensor:
         """Compute categorical cross-entropy loss for the model.
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of steps
@@ -114,11 +120,11 @@ class ViTSTR(_ViTSTR, Model):
     def call(
         self,
         x: tf.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
 
         if target is not None:
@@ -136,7 +142,7 @@ class ViTSTR(_ViTSTR, Model):
         )  # (batch_size, max_length, vocab + 1)
         decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
-        out:
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -158,14 +164,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
@@ -191,7 +196,7 @@ def _vitstr(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    input_shape:
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> ViTSTR:
     # Patch the config
@@ -221,9 +226,7 @@ def _vitstr(
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
@@ -239,12 +242,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
@@ -268,12 +269,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
doctr/models/recognition/zoo.py
CHANGED

@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from doctr.models.preprocessor import PreProcessor
 
 from .. import recognition
@@ -14,7 +14,7 @@ from .predictor import RecognitionPredictor
 __all__ = ["recognition_predictor"]
 
 
-ARCHS:
+ARCHS: list[str] = [
     "crnn_vgg16_bn",
     "crnn_mobilenet_v3_small",
     "crnn_mobilenet_v3_large",
@@ -25,6 +25,9 @@ ARCHS: List[str] = [
     "parseq",
 ]
 
+if is_torch_available():
+    ARCHS.extend(["viptr_tiny"])
+
 
 def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
     if isinstance(arch, str):
@@ -35,9 +38,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
             pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
         )
     else:
-
-
-
+        allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
+        if is_torch_available():
+            # Add VIPTR which is only available in torch at the moment
+            allowed_archs.append(recognition.VIPTR)
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
            raise ValueError(f"unknown architecture: {type(arch)}")
        _model = arch
 
@@ -52,7 +62,13 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
     return predictor
 
 
-def recognition_predictor(
+def recognition_predictor(
+    arch: Any = "crnn_vgg16_bn",
+    pretrained: bool = False,
+    symmetric_pad: bool = False,
+    batch_size: int = 128,
+    **kwargs: Any,
+) -> RecognitionPredictor:
     """Text recognition architecture.
 
     Example::
@@ -63,13 +79,13 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False,
     >>> out = model([input_page])
 
     Args:
-    ----
         arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
         pretrained: If True, returns a model pre-trained on our text recognition dataset
+        symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
+        batch_size: number of samples the model processes in parallel
         **kwargs: optional parameters to be passed to the architecture
 
     Returns:
-    -------
         Recognition predictor
     """
-    return _predictor(arch, pretrained, **kwargs)
+    return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs)
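
A minimal sketch of how the expanded recognition_predictor signature above could be called; symmetric_pad and batch_size values here are illustrative, not the defaults from the release:

    from doctr.models import recognition_predictor

    predictor = recognition_predictor(
        arch="crnn_vgg16_bn",
        pretrained=True,
        symmetric_pad=True,  # pad crops symmetrically rather than bottom-right
        batch_size=64,       # number of crops processed in parallel
    )
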
doctr/models/utils/__init__.py
CHANGED

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
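
With this dispatch in place, the framework-specific helpers resolve through a single import path; a minimal sketch, assuming a PyTorch install:

    # Resolves to doctr/models/utils/pytorch.py when torch is available
    from doctr.models.utils import load_pretrained_params, export_model_to_onnx
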
doctr/models/utils/pytorch.py
CHANGED

@@ -1,12 +1,13 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-from typing import Any
+from typing import Any
 
 import torch
+import validators
 from torch import nn
 
 from doctr.utils.data import download_from_url
@@ -18,8 +19,12 @@ __all__ = [
     "export_model_to_onnx",
     "_copy_tensor",
     "_bf16_to_float32",
+    "_CompiledModule",
 ]
 
+# torch compiled model type
+_CompiledModule = torch._dynamo.eval_frame.OptimizedModule
+
 
 def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
     return x.clone().detach()
@@ -32,42 +37,50 @@ def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
 
 def load_pretrained_params(
     model: nn.Module,
-
-    hash_prefix:
-    ignore_keys:
+    path_or_url: str | None = None,
+    hash_prefix: str | None = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> None:
     """Load a set of parameters onto a model
 
     >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.
+    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")
 
     Args:
-    ----
         model: the PyTorch model to be loaded
-
+        path_or_url: the path or URL to the model parameters (checkpoint)
         hash_prefix: first characters of SHA256 expected hash
         ignore_keys: list of weights to be ignored from the state_dict
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
-    if
-        logging.warning("
-
-
+    if path_or_url is None:
+        logging.warning("No model URL or Path provided, using default initialization.")
+        return
+
+    archive_path = (
+        download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
+        if validators.url(path_or_url)
+        else path_or_url
+    )
 
-
-
+    # Read state_dict
+    state_dict = torch.load(archive_path, map_location="cpu")
 
-
-
-
+    # Remove weights from the state_dict
+    if ignore_keys is not None and len(ignore_keys) > 0:
+        for key in ignore_keys:
+            if key in state_dict:
                 state_dict.pop(key)
-
-
-
-
-
-
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+        if any(k not in ignore_keys for k in missing_keys + unexpected_keys):
+            raise ValueError(
+                "Unable to load state_dict, due to non-matching keys.\n"
+                + f"Unexpected keys: {unexpected_keys}\nMissing keys: {missing_keys}"
+            )
+    else:
+        # Load weights
+        model.load_state_dict(state_dict)
 
 
 def conv_sequence_pt(
@@ -76,7 +89,7 @@ def conv_sequence_pt(
     relu: bool = False,
     bn: bool = False,
     **kwargs: Any,
-) ->
+) -> list[nn.Module]:
     """Builds a convolutional-based layer sequence
 
     >>> from torch.nn import Sequential
@@ -84,7 +97,6 @@ def conv_sequence_pt(
     >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3))
 
     Args:
-    ----
         in_channels: number of input channels
         out_channels: number of output channels
         relu: whether ReLU should be used
@@ -92,13 +104,12 @@ def conv_sequence_pt(
         **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
-    -------
         list of layers
     """
     # No bias before Batch norm
     kwargs["bias"] = kwargs.get("bias", not bn)
     # Add activation directly to the conv if there is no BN
-    conv_seq:
+    conv_seq: list[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
 
     if bn:
         conv_seq.append(nn.BatchNorm2d(out_channels))
@@ -110,8 +121,8 @@ def conv_sequence_pt(
 
 
 def set_device_and_dtype(
-    model: Any, batches:
-) ->
+    model: Any, batches: list[torch.Tensor], device: str | torch.device, dtype: torch.dtype
+) -> tuple[Any, list[torch.Tensor]]:
     """Set the device and dtype of a model and its batches
 
     >>> import torch
@@ -122,14 +133,12 @@ def set_device_and_dtype(
     >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)
 
     Args:
-    ----
         model: the model to be set
         batches: the batches to be set
         device: the device to be used
         dtype: the dtype to be used
 
     Returns:
-    -------
         the model and batches set
     """
     return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches]
@@ -145,19 +154,17 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))
 
     Args:
-    ----
         model: the PyTorch model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to torch.onnx.export
 
     Returns:
-    -------
         the path to the exported model
     """
     torch.onnx.export(
         model,
-        dummy_input,
+        dummy_input,
         f"{model_name}.onnx",
         input_names=["input"],
         output_names=["logits"],