python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +8 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +7 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +4 -5
- doctr/datasets/ic13.py +4 -5
- doctr/datasets/iiit5k.py +6 -5
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +6 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +6 -5
- doctr/datasets/svt.py +4 -5
- doctr/datasets/synthtext.py +4 -5
- doctr/datasets/utils.py +34 -29
- doctr/datasets/vocabs.py +17 -7
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +2 -7
- doctr/io/elements.py +59 -79
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +30 -48
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +8 -11
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +5 -17
- doctr/models/classification/mobilenet/tensorflow.py +8 -21
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +6 -8
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +20 -31
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +8 -15
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +9 -12
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +6 -12
- doctr/models/classification/zoo.py +19 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +14 -26
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +14 -23
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +3 -7
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +18 -19
- doctr/models/kie_predictor/tensorflow.py +13 -14
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +12 -13
- doctr/models/predictor/tensorflow.py +8 -9
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +11 -23
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +12 -22
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +16 -22
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +12 -21
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +12 -20
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +14 -17
- doctr/models/utils/tensorflow.py +17 -16
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +16 -47
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/vitstr/tensorflow.py
CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, layers
@@ -17,7 +17,7 @@ from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "vitstr_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -40,7 +40,6 @@ class ViTSTR(_ViTSTR, Model):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -51,7 +50,7 @@ class ViTSTR(_ViTSTR, Model):
         cfg: dictionary containing information about the model
     """
 
-    _children_names:
+    _children_names: list[str] = ["feat_extractor", "postprocessor"]
 
     def __init__(
         self,
@@ -60,9 +59,9 @@ class ViTSTR(_ViTSTR, Model):
         embedding_units: int,
         max_length: int = 32,
         dropout_prob: float = 0.0,
-        input_shape:
+        input_shape: tuple[int, int, int] = (32, 128, 3),  # different from paper
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -79,19 +78,17 @@ class ViTSTR(_ViTSTR, Model):
     def compute_loss(
         model_output: tf.Tensor,
         gt: tf.Tensor,
-        seq_len:
+        seq_len: list[int],
     ) -> tf.Tensor:
         """Compute categorical cross-entropy loss for the model.
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of steps
@@ -114,11 +111,11 @@ class ViTSTR(_ViTSTR, Model):
     def call(
         self,
         x: tf.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
 
         if target is not None:
@@ -136,7 +133,7 @@ class ViTSTR(_ViTSTR, Model):
         )  # (batch_size, max_length, vocab + 1)
         decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
-        out:
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -158,14 +155,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
@@ -191,7 +187,7 @@ def _vitstr(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    input_shape:
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> ViTSTR:
     # Patch the config
@@ -239,12 +235,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
@@ -268,12 +262,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
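For context, a minimal sketch of how the updated TensorFlow ViTSTR builders are typically exercised, assuming a TensorFlow-backend install of doctr; the (32, 128, 3) shape mirrors the default input_shape shown in the diff above, and the output structure follows the annotated call/post-processor return types.

import tensorflow as tf
from doctr.models import vitstr_small

# Random-weight model; pass pretrained=True to download the released checkpoint
model = vitstr_small(pretrained=False)
# Dummy batch matching the default input_shape (32, 128, 3) from the diff above
input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1.0, dtype=tf.float32)
out = model(input_tensor)
print(out["preds"])  # list of (word, confidence) pairs from ViTSTRPostProcessor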
doctr/models/recognition/zoo.py
CHANGED

@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any
+from typing import Any
 
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from doctr.models.preprocessor import PreProcessor
 
 from .. import recognition
@@ -14,7 +14,7 @@ from .predictor import RecognitionPredictor
 
 __all__ = ["recognition_predictor"]
 
 
-ARCHS:
+ARCHS: list[str] = [
     "crnn_vgg16_bn",
     "crnn_mobilenet_v3_small",
     "crnn_mobilenet_v3_large",
@@ -35,9 +35,14 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
             pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
         )
     else:
-
-
-
+        allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch
 
@@ -52,7 +57,13 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
     return predictor
 
 
-def recognition_predictor(
+def recognition_predictor(
+    arch: Any = "crnn_vgg16_bn",
+    pretrained: bool = False,
+    symmetric_pad: bool = False,
+    batch_size: int = 128,
+    **kwargs: Any,
+) -> RecognitionPredictor:
     """Text recognition architecture.
 
     Example::
@@ -63,13 +74,13 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False,
     >>> out = model([input_page])
 
     Args:
-    ----
         arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
         pretrained: If True, returns a model pre-trained on our text recognition dataset
+        symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
+        batch_size: number of samples the model processes in parallel
         **kwargs: optional parameters to be passed to the architecture
 
     Returns:
-    -------
         Recognition predictor
     """
-    return _predictor(arch, pretrained, **kwargs)
+    return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs)
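A short usage sketch of the expanded recognition_predictor signature, with the symmetric_pad and batch_size keywords now promoted to the function signature above; the image path is a placeholder.

from doctr.io import DocumentFile
from doctr.models import recognition_predictor

predictor = recognition_predictor(
    arch="crnn_vgg16_bn",
    pretrained=True,
    symmetric_pad=True,  # pad crops symmetrically instead of bottom-right
    batch_size=64,       # number of crops processed per forward pass
)
crops = DocumentFile.from_images("path/to/word_crop.png")  # placeholder path
out = predictor(crops)  # list of (word, confidence) tuples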
doctr/models/utils/__init__.py
CHANGED

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
-    from .
-elif
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/utils/pytorch.py
CHANGED

@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-from typing import Any
+from typing import Any
 
 import torch
 from torch import nn
@@ -18,8 +18,12 @@ __all__ = [
     "export_model_to_onnx",
     "_copy_tensor",
     "_bf16_to_float32",
+    "_CompiledModule",
 ]
 
+# torch compiled model type
+_CompiledModule = torch._dynamo.eval_frame.OptimizedModule
+
 
 def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
     return x.clone().detach()
@@ -32,9 +36,9 @@ def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
 
 def load_pretrained_params(
     model: nn.Module,
-    url:
-    hash_prefix:
-    ignore_keys:
+    url: str | None = None,
+    hash_prefix: str | None = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> None:
     """Load a set of parameters onto a model
@@ -43,7 +47,6 @@ def load_pretrained_params(
     >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
 
     Args:
-    ----
         model: the PyTorch model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
@@ -76,7 +79,7 @@ def conv_sequence_pt(
     relu: bool = False,
     bn: bool = False,
     **kwargs: Any,
-) ->
+) -> list[nn.Module]:
     """Builds a convolutional-based layer sequence
 
     >>> from torch.nn import Sequential
@@ -84,7 +87,6 @@ def conv_sequence_pt(
     >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3))
 
     Args:
-    ----
         in_channels: number of input channels
         out_channels: number of output channels
         relu: whether ReLU should be used
@@ -92,13 +94,12 @@ def conv_sequence_pt(
         **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
-    -------
         list of layers
     """
     # No bias before Batch norm
     kwargs["bias"] = kwargs.get("bias", not bn)
     # Add activation directly to the conv if there is no BN
-    conv_seq:
+    conv_seq: list[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
 
     if bn:
         conv_seq.append(nn.BatchNorm2d(out_channels))
@@ -110,8 +111,8 @@ def conv_sequence_pt(
 
 
 def set_device_and_dtype(
-    model: Any, batches:
-) ->
+    model: Any, batches: list[torch.Tensor], device: str | torch.device, dtype: torch.dtype
+) -> tuple[Any, list[torch.Tensor]]:
     """Set the device and dtype of a model and its batches
 
     >>> import torch
@@ -122,14 +123,12 @@ def set_device_and_dtype(
     >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)
 
     Args:
-    ----
         model: the model to be set
         batches: the batches to be set
         device: the device to be used
         dtype: the dtype to be used
 
     Returns:
-    -------
         the model and batches set
     """
     return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches]
@@ -145,19 +144,17 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))
 
     Args:
-    ----
         model: the PyTorch model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to torch.onnx.export
 
     Returns:
-    -------
         the path to the exported model
     """
     torch.onnx.export(
         model,
-        dummy_input,
+        dummy_input,
         f"{model_name}.onnx",
         input_names=["input"],
         output_names=["logits"],
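A sketch of what the new _CompiledModule alias enables, assuming a PyTorch-backend install: torch.compile wraps a module in torch._dynamo.eval_frame.OptimizedModule, which the recognition zoo above now accepts in its isinstance check. The model choice is illustrative.

import torch
from doctr.models import crnn_vgg16_bn, recognition_predictor

model = crnn_vgg16_bn(pretrained=True)
compiled = torch.compile(model)  # an OptimizedModule, i.e. a _CompiledModule

# Since 0.11.0 the predictor factories also accept the compiled wrapper
predictor = recognition_predictor(arch=compiled)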
doctr/models/utils/tensorflow.py
CHANGED

@@ -1,10 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-from
+from collections.abc import Callable
+from typing import Any
 
 import tensorflow as tf
 import tf2onnx
@@ -39,7 +40,6 @@ def _build_model(model: Model):
     """Build a model by calling it once with dummy input
 
     Args:
-    ----
         model: the model to be built
     """
     model(tf.zeros((1, *model.cfg["input_shape"])), training=False)
@@ -47,8 +47,8 @@ def _build_model(model: Model):
 
 def load_pretrained_params(
     model: Model,
-    url:
-    hash_prefix:
+    url: str | None = None,
+    hash_prefix: str | None = None,
     skip_mismatch: bool = False,
     **kwargs: Any,
 ) -> None:
@@ -58,7 +58,6 @@ def load_pretrained_params(
     >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5")
 
     Args:
-    ----
         model: the keras model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
@@ -75,12 +74,12 @@ def load_pretrained_params(
 
 def conv_sequence(
     out_channels: int,
-    activation:
+    activation: str | Callable | None = None,
     bn: bool = False,
     padding: str = "same",
     kernel_initializer: str = "he_normal",
     **kwargs: Any,
-) ->
+) -> list[layers.Layer]:
     """Builds a convolutional-based layer sequence
 
     >>> from tensorflow.keras import Sequential
@@ -88,7 +87,6 @@ def conv_sequence(
     >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))
 
     Args:
-    ----
         out_channels: number of output channels
         activation: activation to be used (default: no activation)
         bn: should a batch normalization layer be added
@@ -97,7 +95,6 @@ def conv_sequence(
         **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
-    -------
         list of layers
     """
     # No bias before Batch norm
@@ -125,12 +122,11 @@ class IntermediateLayerGetter(Model):
     >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers)
 
     Args:
-    ----
         model: the model to extract feature maps from
         layer_names: the list of layers to retrieve the feature map from
     """
 
-    def __init__(self, model: Model, layer_names:
+    def __init__(self, model: Model, layer_names: list[str]) -> None:
         intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names]
         super().__init__(model.input, outputs=intermediate_fmaps)
 
@@ -139,8 +135,8 @@ class IntermediateLayerGetter(Model):
 
 
 def export_model_to_onnx(
-    model: Model, model_name: str, dummy_input:
-) ->
+    model: Model, model_name: str, dummy_input: list[tf.TensorSpec], **kwargs: Any
+) -> tuple[str, list[str]]:
     """Export model to ONNX format.
 
     >>> import tensorflow as tf
@@ -151,16 +147,18 @@ def export_model_to_onnx(
     >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")])
 
     Args:
-    ----
         model: the keras model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to tf2onnx
 
     Returns:
-    -------
         the path to the exported model and a list with the output layer names
     """
+    # get the users eager mode
+    eager_mode = tf.executing_eagerly()
+    # set eager mode to true to avoid issues with tf2onnx
+    tf.config.run_functions_eagerly(True)
     large_model = kwargs.get("large_model", False)
     model_proto, _ = tf2onnx.convert.from_keras(
         model,
@@ -171,6 +169,9 @@ def export_model_to_onnx(
     # Get the output layer names
     output = [n.name for n in model_proto.graph.output]
 
+    # reset the eager mode to the users mode
+    tf.config.run_functions_eagerly(eager_mode)
+
     # models which are too large (weights > 2GB while converting to ONNX) needs to be handled
     # about an external tensor storage where the graph and weights are seperatly stored in a archive
     if large_model:
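A sketch of the TensorFlow export path that now saves and restores the user's eager mode, following the dummy-input pattern from the function's own docstring; the classification model and 32x32 shape are illustrative choices.

import tensorflow as tf
from doctr.models.classification import resnet18
from doctr.models.utils import export_model_to_onnx

model = resnet18(pretrained=False)
# Build the graph once with a dummy batch (mirrors _build_model above)
model(tf.zeros((1, 32, 32, 3)), training=False)
model_path, output_names = export_model_to_onnx(
    model,
    model_name="resnet18",
    dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")],
)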
doctr/models/zoo.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -83,7 +83,6 @@ def ocr_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -108,7 +107,6 @@ def ocr_predictor(
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
-    -------
         OCR predictor
     """
     return _predictor(
@@ -197,7 +195,6 @@ def kie_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -222,7 +219,6 @@ def kie_predictor(
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
-    -------
         KIE predictor
     """
     return _kie_predictor(
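For reference, a minimal end-to-end sketch of the ocr_predictor entry point whose docstring is touched above, using the detection and recognition architectures cited in that docstring; the PDF path is a placeholder.

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# Detection + recognition pipeline
model = ocr_predictor(det_arch="db_resnet50", reco_arch="crnn_vgg16_bn", pretrained=True)
doc = DocumentFile.from_pdf("path/to/document.pdf")  # placeholder path
result = model(doc)
print(result.render())  # plain-text export of the reconstructed pages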
doctr/transforms/functional/base.py
CHANGED

@@ -1,9 +1,8 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Tuple, Union
 
 import cv2
 import numpy as np
@@ -15,17 +14,15 @@ __all__ = ["crop_boxes", "create_shadow_mask"]
 
 def crop_boxes(
     boxes: np.ndarray,
-    crop_box:
+    crop_box: tuple[int, int, int, int] | tuple[float, float, float, float],
 ) -> np.ndarray:
     """Crop localization boxes
 
     Args:
-    ----
         boxes: ndarray of shape (N, 4) in relative or abs coordinates
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in the same coord format that the boxes
 
     Returns:
-    -------
         the cropped boxes
     """
     is_box_rel = boxes.max() <= 1
@@ -49,17 +46,15 @@ def crop_boxes(
     return boxes[is_valid]
 
 
-def expand_line(line: np.ndarray, target_shape:
+def expand_line(line: np.ndarray, target_shape: tuple[int, int]) -> tuple[float, float]:
     """Expands a 2-point line, so that the first is on the edge. In other terms, we extend the line in
     the same direction until we meet one of the edges.
 
     Args:
-    ----
         line: array of shape (2, 2) of the point supposed to be on one edge, and the shadow tip.
         target_shape: the desired mask shape
 
     Returns:
-    -------
         2D coordinates of the first point once we extended the line (on one of the edges)
     """
     if any(coord == 0 or coord == size for coord, size in zip(line[0], target_shape[::-1])):
@@ -112,7 +107,7 @@ def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float,
 
 
 def create_shadow_mask(
-    target_shape:
+    target_shape: tuple[int, int],
     min_base_width=0.3,
     max_tip_width=0.5,
     max_tip_height=0.3,
@@ -120,14 +115,12 @@ def create_shadow_mask(
     """Creates a random shadow mask
 
     Args:
-    ----
         target_shape: the target shape (H, W)
         min_base_width: the relative minimum shadow base width
         max_tip_width: the relative maximum shadow tip width
         max_tip_height: the relative maximum shadow tip height
 
     Returns:
-    -------
         a numpy ndarray of shape (H, W, 1) with values in the range [0, 1]
     """
     # Default base is top