python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +17 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +17 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +14 -5
- doctr/datasets/ic13.py +13 -5
- doctr/datasets/iiit5k.py +31 -20
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +16 -5
- doctr/datasets/svhn.py +16 -5
- doctr/datasets/svt.py +14 -5
- doctr/datasets/synthtext.py +14 -5
- doctr/datasets/utils.py +37 -27
- doctr/datasets/vocabs.py +21 -7
- doctr/datasets/wildreceipt.py +25 -10
- doctr/file_utils.py +18 -4
- doctr/io/elements.py +69 -81
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +32 -50
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +21 -17
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +7 -17
- doctr/models/classification/mobilenet/tensorflow.py +22 -29
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +13 -11
- doctr/models/classification/predictor/tensorflow.py +13 -11
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +41 -39
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +19 -20
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +18 -15
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +16 -16
- doctr/models/classification/zoo.py +36 -19
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +28 -37
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +36 -33
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +7 -8
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +8 -13
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +8 -5
- doctr/models/kie_predictor/pytorch.py +22 -19
- doctr/models/kie_predictor/tensorflow.py +21 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -12
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +3 -4
- doctr/models/modules/vision_transformer/tensorflow.py +4 -4
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +52 -41
- doctr/models/predictor/pytorch.py +16 -13
- doctr/models/predictor/tensorflow.py +16 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +11 -15
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +19 -29
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +21 -26
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +26 -30
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +19 -24
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +21 -24
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +13 -16
- doctr/models/utils/tensorflow.py +31 -30
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +21 -29
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +65 -28
- doctr/transforms/modules/tensorflow.py +33 -44
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +120 -64
- doctr/utils/metrics.py +18 -38
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +157 -75
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.9.0.dist-info/RECORD +0 -173
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/vitstr/pytorch.py
CHANGED

@@ -1,10 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import torch
 from torch import nn
@@ -19,7 +20,7 @@ from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "vitstr_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -42,7 +43,6 @@ class ViTSTR(_ViTSTR, nn.Module):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -59,9 +59,9 @@ class ViTSTR(_ViTSTR, nn.Module):
         vocab: str,
         embedding_units: int,
         max_length: int = 32,  # different from paper
-        input_shape:
+        input_shape: tuple[int, int, int] = (3, 32, 128),  # different from paper
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -77,10 +77,10 @@ class ViTSTR(_ViTSTR, nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]  # (batch_size, patches_seqlen, d_model)
 
         if target is not None:
@@ -98,7 +98,7 @@ class ViTSTR(_ViTSTR, nn.Module):
         logits = self.head(features).view(B, N, len(self.vocab) + 1)  # (batch_size, max_length, vocab + 1)
         decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
-        out:
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -107,8 +107,13 @@ class ViTSTR(_ViTSTR, nn.Module):
             out["out_map"] = decoded_features
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
             # Post-process boxes
-            out["preds"] =
+            out["preds"] = _postprocess(decoded_features)
 
         if target is not None:
             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
@@ -125,19 +130,17 @@ class ViTSTR(_ViTSTR, nn.Module):
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of steps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1  # type: ignore[assignment]
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>.
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -153,14 +156,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: torch.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
         preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
@@ -183,7 +185,7 @@ def _vitstr(
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
-    ignore_keys:
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> ViTSTR:
     # Patch the config
@@ -228,12 +230,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
@@ -259,12 +259,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
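The functional change above is the small post-processing closure decorated with torch.compiler.disable, which forces a graph break so torch.compile never traces the Python-level string decoding. Below is a minimal, self-contained sketch of the same pattern, assuming PyTorch 2.x where torch.compiler.disable is available; TinyRecognizer and its shapes are illustrative and not part of doctr:

import torch
from torch import nn


class TinyRecognizer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(8, 5)

    def forward(self, x: torch.Tensor) -> dict:
        logits = self.head(x)  # numeric part: safe to compile

        @torch.compiler.disable  # graph break: decoding produces plain Python objects
        def _postprocess(logits: torch.Tensor) -> list[int]:
            return logits.argmax(-1).tolist()

        return {"logits": logits, "preds": _postprocess(logits)}


model = torch.compile(TinyRecognizer())
out = model(torch.randn(2, 8))  # compiled forward pass, eager post-processing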
doctr/models/recognition/vitstr/tensorflow.py
CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, layers
@@ -12,25 +12,25 @@ from tensorflow.keras import Model, layers
 from doctr.datasets import VOCABS
 
 from ...classification import vit_b, vit_s
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "vitstr_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
     },
     "vitstr_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0",
     },
 }
 
@@ -40,7 +40,6 @@ class ViTSTR(_ViTSTR, Model):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -51,7 +50,7 @@ class ViTSTR(_ViTSTR, Model):
         cfg: dictionary containing information about the model
     """
 
-    _children_names:
+    _children_names: list[str] = ["feat_extractor", "postprocessor"]
 
     def __init__(
         self,
@@ -60,9 +59,9 @@ class ViTSTR(_ViTSTR, Model):
         embedding_units: int,
         max_length: int = 32,
         dropout_prob: float = 0.0,
-        input_shape:
+        input_shape: tuple[int, int, int] = (32, 128, 3),  # different from paper
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -79,19 +78,17 @@ class ViTSTR(_ViTSTR, Model):
     def compute_loss(
         model_output: tf.Tensor,
         gt: tf.Tensor,
-        seq_len:
+        seq_len: list[int],
     ) -> tf.Tensor:
         """Compute categorical cross-entropy loss for the model.
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of steps
@@ -114,11 +111,11 @@ class ViTSTR(_ViTSTR, Model):
     def call(
         self,
         x: tf.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
 
         if target is not None:
@@ -136,7 +133,7 @@ class ViTSTR(_ViTSTR, Model):
         )  # (batch_size, max_length, vocab + 1)
         decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
-        out:
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -158,14 +155,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
@@ -191,7 +187,7 @@ def _vitstr(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    input_shape:
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> ViTSTR:
     # Patch the config
@@ -216,9 +212,14 @@ def _vitstr(
 
     # Build the model
     model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
 
@@ -234,12 +235,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
@@ -263,12 +262,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
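On the TensorFlow side, the loader now calls _build_model before loading weights and passes skip_mismatch to load_pretrained_params whenever the requested vocab differs from the pretrained one, so a mismatching classification head is skipped rather than raising. A hedged usage sketch for the TensorFlow backend; the digits vocab is an arbitrary example:

from doctr.datasets import VOCABS
from doctr.models.recognition import vitstr_small

# Same vocab as the checkpoint: the weights are expected to load strictly.
model = vitstr_small(pretrained=True)

# Different vocab: the head no longer matches the checkpoint, so the loader
# should skip the mismatching layers and leave them ready for fine-tuning.
finetune_model = vitstr_small(pretrained=True, vocab=VOCABS["digits"])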
doctr/models/recognition/zoo.py
CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from doctr.models.preprocessor import PreProcessor
 
 from .. import recognition
@@ -14,7 +14,7 @@ from .predictor import RecognitionPredictor
 __all__ = ["recognition_predictor"]
 
 
-ARCHS:
+ARCHS: list[str] = [
     "crnn_vgg16_bn",
     "crnn_mobilenet_v3_small",
     "crnn_mobilenet_v3_large",
@@ -35,9 +35,14 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
             pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
         )
     else:
-
-
-
+        allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch
 
@@ -52,7 +57,13 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
     return predictor
 
 
-def recognition_predictor(
+def recognition_predictor(
+    arch: Any = "crnn_vgg16_bn",
+    pretrained: bool = False,
+    symmetric_pad: bool = False,
+    batch_size: int = 128,
+    **kwargs: Any,
+) -> RecognitionPredictor:
     """Text recognition architecture.
 
     Example::
@@ -63,13 +74,13 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False,
     >>> out = model([input_page])
 
     Args:
-    ----
         arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
         pretrained: If True, returns a model pre-trained on our text recognition dataset
+        symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
+        batch_size: number of samples the model processes in parallel
         **kwargs: optional parameters to be passed to the architecture
 
     Returns:
-    -------
         Recognition predictor
     """
-    return _predictor(arch, pretrained, **kwargs)
+    return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs)
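recognition_predictor now exposes symmetric_pad and batch_size as explicit keyword arguments and forwards everything to _predictor by keyword. A usage sketch; the dummy crop shape is illustrative:

import numpy as np
from doctr.models import recognition_predictor

predictor = recognition_predictor(
    arch="crnn_vgg16_bn",
    pretrained=True,
    symmetric_pad=True,  # pad word crops on both sides instead of bottom-right only
    batch_size=64,       # number of crops processed per forward pass
)
crop = (np.random.rand(32, 128, 3) * 255).astype(np.uint8)  # a dummy word crop
out = predictor([crop])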
doctr/models/utils/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
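Because the package init re-exports whichever backend module is available, downstream code can import helpers from doctr.models.utils without knowing which framework is installed. A quick sketch to see which implementation was picked up (illustrative, not part of doctr's API):

from doctr.file_utils import is_tf_available, is_torch_available
from doctr.models.utils import load_pretrained_params

backend = "PyTorch" if is_torch_available() else "TensorFlow" if is_tf_available() else "none"
print(backend, load_pretrained_params.__module__)  # e.g. doctr.models.utils.pytorch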
doctr/models/utils/pytorch.py
CHANGED
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-from typing import Any
+from typing import Any
 
 import torch
 from torch import nn
@@ -18,8 +18,12 @@ __all__ = [
     "export_model_to_onnx",
     "_copy_tensor",
     "_bf16_to_float32",
+    "_CompiledModule",
 ]
 
+# torch compiled model type
+_CompiledModule = torch._dynamo.eval_frame.OptimizedModule
+
 
 def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
     return x.clone().detach()
@@ -32,9 +36,9 @@ def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
 
 def load_pretrained_params(
     model: nn.Module,
-    url:
-    hash_prefix:
-    ignore_keys:
+    url: str | None = None,
+    hash_prefix: str | None = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> None:
     """Load a set of parameters onto a model
@@ -43,7 +47,6 @@ def load_pretrained_params(
     >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
 
     Args:
-    ----
         model: the PyTorch model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
@@ -76,7 +79,7 @@ def conv_sequence_pt(
     relu: bool = False,
     bn: bool = False,
     **kwargs: Any,
-) ->
+) -> list[nn.Module]:
     """Builds a convolutional-based layer sequence
 
     >>> from torch.nn import Sequential
@@ -84,7 +87,6 @@ def conv_sequence_pt(
     >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3))
 
     Args:
-    ----
         in_channels: number of input channels
         out_channels: number of output channels
         relu: whether ReLU should be used
@@ -92,13 +94,12 @@ def conv_sequence_pt(
         **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
-    -------
         list of layers
     """
     # No bias before Batch norm
     kwargs["bias"] = kwargs.get("bias", not bn)
     # Add activation directly to the conv if there is no BN
-    conv_seq:
+    conv_seq: list[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
 
     if bn:
         conv_seq.append(nn.BatchNorm2d(out_channels))
@@ -110,8 +111,8 @@ def conv_sequence_pt(
 
 
 def set_device_and_dtype(
-    model: Any, batches:
-) ->
+    model: Any, batches: list[torch.Tensor], device: str | torch.device, dtype: torch.dtype
+) -> tuple[Any, list[torch.Tensor]]:
     """Set the device and dtype of a model and its batches
 
     >>> import torch
@@ -122,14 +123,12 @@ def set_device_and_dtype(
     >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)
 
     Args:
-    ----
         model: the model to be set
         batches: the batches to be set
         device: the device to be used
         dtype: the dtype to be used
 
     Returns:
-    -------
         the model and batches set
     """
     return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches]
@@ -145,14 +144,12 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))
 
     Args:
-    ----
         model: the PyTorch model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to torch.onnx.export
 
     Returns:
-    -------
         the path to the exported model
     """
     torch.onnx.export(