python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package, as published to its public registry, and is provided for informational purposes only.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2024, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -14,7 +14,7 @@ from torchvision.models._utils import IntermediateLayerGetter
 from doctr.datasets import VOCABS

 from ...classification import vit_b, vit_s
-from ...utils.pytorch import load_pretrained_params
+from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor

 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

@@ -25,14 +25,14 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 "std": (0.299, 0.296, 0.301),
 "input_shape": (3, 32, 128),
 "vocab": VOCABS["french"],
-"url":
+"url": "https://doctr-static.mindee.com/models?id=v0.7.0/vitstr_small-fcd12655.pt&src=0",
 },
 "vitstr_base": {
 "mean": (0.694, 0.695, 0.693),
 "std": (0.299, 0.296, 0.301),
 "input_shape": (3, 32, 128),
 "vocab": VOCABS["french"],
-"url":
+"url": "https://doctr-static.mindee.com/models?id=v0.7.0/vitstr_base-50b21df2.pt&src=0",
 },
 }

@@ -42,6 +42,7 @@ class ViTSTR(_ViTSTR, nn.Module):
 Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.

 Args:
+----
 feature_extractor: the backbone serving as feature extractor
 vocab: vocabulary used for encoding
 embedding_units: number of embedding units

@@ -95,7 +96,7 @@ class ViTSTR(_ViTSTR, nn.Module):
 B, N, E = features.size()
 features = features.reshape(B * N, E)
 logits = self.head(features).view(B, N, len(self.vocab) + 1)  # (batch_size, max_length, vocab + 1)
-decoded_features = logits[:, 1:]  # remove cls_token
+decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token

 out: Dict[str, Any] = {}
 if self.exportable:

@@ -124,17 +125,19 @@ class ViTSTR(_ViTSTR, nn.Module):
 Sequences are masked after the EOS character.

 Args:
+----
 model_output: predicted logits of the model
 gt: the encoded tensor with gt labels
 seq_len: lengths of each gt word inside the batch

 Returns:
+-------
 The loss of the model on the batch
 """
 # Input length : number of steps
 input_len = model_output.shape[1]
 # Add one for additional <eos> token (sos disappear in shift!)
-seq_len = seq_len + 1
+seq_len = seq_len + 1  # type: ignore[assignment]
 # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
 # The "masked" first gt char is <sos>.
 cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")

@@ -150,6 +153,7 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
 """Post processor for ViTSTR architecture

 Args:
+----
 vocab: string containing the ordered sequence of supported characters
 """

@@ -159,18 +163,19 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
 ) -> List[Tuple[str, float]]:
 # compute pred with argmax for attention models
 out_idxs = logits.argmax(-1)
-
-probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
-# Take the minimum confidence of the sequence
-probs = probs.min(dim=1).values.detach().cpu()
+preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]

 # Manual decoding
 word_values = [
 "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
 for encoded_seq in out_idxs.cpu().numpy()
 ]
+# compute probabilties for each word up to the EOS token
+probs = [
+preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
+]

-return list(zip(word_values, probs
+return list(zip(word_values, probs))


 def _vitstr(

@@ -223,12 +228,14 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
 >>> out = model(input_tensor)

 Args:
+----
 pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+kwargs: keyword arguments of the ViTSTR architecture

 Returns:
+-------
 text recognition architecture
 """
-
 return _vitstr(
 "vitstr_small",
 pretrained,

@@ -252,12 +259,14 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
 >>> out = model(input_tensor)

 Args:
+----
 pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+kwargs: keyword arguments of the ViTSTR architecture

 Returns:
+-------
 text recognition architecture
 """
-
 return _vitstr(
 "vitstr_base",
 pretrained,
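The post-processor hunks above replace the old gather/min sequence confidence with the per-character softmax maximum, averaged over each decoded word up to the EOS token. A minimal standalone sketch of that logic on random logits (the toy vocab, embedding list, and tensor shapes are illustrative, not doctr's actual objects):

    import torch

    vocab = "abc"
    embedding = list(vocab) + ["<eos>"]  # last class index is reserved for <eos>

    logits = torch.randn(2, 5, len(vocab) + 1)  # (batch_size, max_length, vocab + 1)

    out_idxs = logits.argmax(-1)
    # keep the max softmax probability at every decoding step...
    preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]

    word_values = [
        "".join(embedding[idx] for idx in seq).split("<eos>")[0]
        for seq in out_idxs.cpu().numpy()
    ]
    # ...then average it over the characters kept for each word (up to <eos>),
    # instead of taking the minimum over the whole sequence as in 0.7.0
    probs = [
        preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0
        for i, word in enumerate(word_values)
    ]
    print(list(zip(word_values, probs)))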
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2024, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -12,7 +12,7 @@ from tensorflow.keras import Model, layers
 from doctr.datasets import VOCABS

 from ...classification import vit_b, vit_s
-from ...utils.tensorflow import load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor

 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

@@ -40,6 +40,7 @@ class ViTSTR(_ViTSTR, Model):
 Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.

 Args:
+----
 feature_extractor: the backbone serving as feature extractor
 vocab: vocabulary used for encoding
 embedding_units: number of embedding units

@@ -84,11 +85,13 @@ class ViTSTR(_ViTSTR, Model):
 Sequences are masked after the EOS character.

 Args:
+----
 model_output: predicted logits of the model
 gt: the encoded tensor with gt labels
 seq_len: lengths of each gt word inside the batch

 Returns:
+-------
 The loss of the model on the batch
 """
 # Input length : number of steps

@@ -131,7 +134,7 @@ class ViTSTR(_ViTSTR, Model):
 logits = tf.reshape(
 self.head(features, **kwargs), (B, N, len(self.vocab) + 1)
 )  # (batch_size, max_length, vocab + 1)
-decoded_features = logits[:, 1:]  # remove cls_token
+decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token

 out: Dict[str, tf.Tensor] = {}
 if self.exportable:

@@ -155,6 +158,7 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
 """Post processor for ViTSTR architecture

 Args:
+----
 vocab: string containing the ordered sequence of supported characters
 """

@@ -164,10 +168,7 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
 ) -> List[Tuple[str, float]]:
 # compute pred with argmax for attention models
 out_idxs = tf.math.argmax(logits, axis=2)
-
-probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
-# Take the minimum confidence of the sequence
-probs = tf.math.reduce_min(probs, axis=1)
+preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)

 # decode raw output of the model with tf_label_to_idx
 out_idxs = tf.cast(out_idxs, dtype="int32")

@@ -177,7 +178,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
 decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
 word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

-
+# compute probabilties for each word up to the EOS token
+probs = [
+preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
+for i, word in enumerate(word_values)
+]
+
+return list(zip(word_values, probs))


 def _vitstr(

@@ -227,12 +234,14 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
 >>> out = model(input_tensor)

 Args:
+----
 pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+**kwargs: keyword arguments of the ViTSTR architecture

 Returns:
+-------
 text recognition architecture
 """
-
 return _vitstr(
 "vitstr_small",
 pretrained,

@@ -254,12 +263,14 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
 >>> out = model(input_tensor)

 Args:
+----
 pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+**kwargs: keyword arguments of the ViTSTR architecture

 Returns:
+-------
 text recognition architecture
 """
-
 return _vitstr(
 "vitstr_base",
 pretrained,
doctr/models/recognition/zoo.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2024, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -63,11 +63,13 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False,
 >>> out = model([input_page])

 Args:
+----
 arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
 pretrained: If True, returns a model pre-trained on our text recognition dataset
+**kwargs: optional parameters to be passed to the architecture

 Returns:
+-------
 Recognition predictor
 """
-
 return _predictor(arch, pretrained, **kwargs)
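The recognition_predictor docstring above already sketches the intended call pattern; spelled out as a runnable snippet (assuming doctr and one of its deep learning backends are installed, with a random word crop as input):

    import numpy as np
    from doctr.models import recognition_predictor

    # mirrors the docstring example: a word crop as an HWC uint8 array
    model = recognition_predictor(arch="crnn_vgg16_bn", pretrained=True)
    input_page = (255 * np.random.rand(32, 128, 3)).astype(np.uint8)
    out = model([input_page])
    print(out)  # list of (word, confidence) pairs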
doctr/models/utils/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2024, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -11,18 +11,29 @@ from torch import nn

 from doctr.utils.data import download_from_url

-__all__ = [
+__all__ = [
+"load_pretrained_params",
+"conv_sequence_pt",
+"set_device_and_dtype",
+"export_model_to_onnx",
+"_copy_tensor",
+"_bf16_to_float32",
+]


 def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
 return x.clone().detach()


+def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
+# bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
+return x.float() if x.dtype == torch.bfloat16 else x
+
+
 def load_pretrained_params(
 model: nn.Module,
 url: Optional[str] = None,
 hash_prefix: Optional[str] = None,
-overwrite: bool = False,
 ignore_keys: Optional[List[str]] = None,
 **kwargs: Any,
 ) -> None:

@@ -32,13 +43,13 @@ def load_pretrained_params(
 >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

 Args:
+----
 model: the PyTorch model to be loaded
 url: URL of the zipped set of parameters
 hash_prefix: first characters of SHA256 expected hash
-overwrite: should the zip extraction be enforced if the archive has already been extracted
 ignore_keys: list of weights to be ignored from the state_dict
+**kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
 """
-
 if url is None:
 logging.warning("Invalid model URL, using default initialization.")
 else:

@@ -73,11 +84,15 @@ def conv_sequence_pt(
 >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3))

 Args:
+----
+in_channels: number of input channels
 out_channels: number of output channels
 relu: whether ReLU should be used
 bn: should a batch normalization layer be added
+**kwargs: additional arguments to be passed to the convolutional layer

 Returns:
+-------
 list of layers
 """
 # No bias before Batch norm

@@ -107,15 +122,16 @@ def set_device_and_dtype(
 >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)

 Args:
+----
 model: the model to be set
 batches: the batches to be set
 device: the device to be used
 dtype: the dtype to be used

 Returns:
+-------
 the model and batches set
 """
-
 return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches]


@@ -129,12 +145,14 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
 >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))

 Args:
+----
 model: the PyTorch model to be exported
 model_name: the name for the exported model
 dummy_input: the dummy input to the model
 kwargs: additional arguments to be passed to torch.onnx.export

 Returns:
+-------
 the path to the exported model
 """
 torch.onnx.export(
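Both backends gain a _bf16_to_float32 helper because NumPy has no bfloat16 dtype: logits produced in bfloat16 (e.g. under mixed-precision autocast) have to be cast back before the post-processors call .numpy() on them. A small sketch of the PyTorch variant in isolation (a standalone copy, not an import from doctr):

    import torch

    def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
        # bfloat16 tensors cannot be converted with .numpy(), so cast them back to float32
        return x.float() if x.dtype == torch.bfloat16 else x

    logits = torch.randn(2, 4, dtype=torch.bfloat16)
    # calling logits.numpy() directly would raise a TypeError for bfloat16 tensors
    arr = _bf16_to_float32(logits).numpy()
    print(arr.dtype)  # float32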
doctr/models/utils/tensorflow.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2024, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -17,13 +17,25 @@ from doctr.utils.data import download_from_url
 logging.getLogger("tensorflow").setLevel(logging.DEBUG)


-__all__ = [
+__all__ = [
+"load_pretrained_params",
+"conv_sequence",
+"IntermediateLayerGetter",
+"export_model_to_onnx",
+"_copy_tensor",
+"_bf16_to_float32",
+]


 def _copy_tensor(x: tf.Tensor) -> tf.Tensor:
 return tf.identity(x)


+def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
+# Convert bfloat16 to float32 for numpy compatibility
+return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x
+
+
 def load_pretrained_params(
 model: Model,
 url: Optional[str] = None,

@@ -38,13 +50,14 @@ def load_pretrained_params(
 >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

 Args:
+----
 model: the keras model to be loaded
 url: URL of the zipped set of parameters
 hash_prefix: first characters of SHA256 expected hash
 overwrite: should the zip extraction be enforced if the archive has already been extracted
 internal_name: name of the ckpt files
+**kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
 """
-
 if url is None:
 logging.warning("Invalid model URL, using default initialization.")
 else:

@@ -75,13 +88,16 @@ def conv_sequence(
 >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))

 Args:
+----
 out_channels: number of output channels
 activation: activation to be used (default: no activation)
 bn: should a batch normalization layer be added
 padding: padding scheme
 kernel_initializer: kernel initializer
+**kwargs: additional arguments to be passed to the convolutional layer

 Returns:
+-------
 list of layers
 """
 # No bias before Batch norm

@@ -109,6 +125,7 @@ class IntermediateLayerGetter(Model):
 >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers)

 Args:
+----
 model: the model to extract feature maps from
 layer_names: the list of layers to retrieve the feature map from
 """

@@ -134,12 +151,14 @@ def export_model_to_onnx(
 >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")])

 Args:
+----
 model: the keras model to be exported
 model_name: the name for the exported model
 dummy_input: the dummy input to the model
 kwargs: additional arguments to be passed to tf2onnx

 Returns:
+-------
 the path to the exported model and a list with the output layer names
 """
 large_model = kwargs.get("large_model", False)
doctr/models/zoo.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2024, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -24,6 +24,7 @@ def _predictor(
 det_bs: int = 2,
 reco_bs: int = 128,
 detect_orientation: bool = False,
+straighten_pages: bool = False,
 detect_language: bool = False,
 **kwargs,
 ) -> OCRPredictor:

@@ -53,6 +54,7 @@
 preserve_aspect_ratio=preserve_aspect_ratio,
 symmetric_pad=symmetric_pad,
 detect_orientation=detect_orientation,
+straighten_pages=straighten_pages,
 detect_language=detect_language,
 **kwargs,
 )

@@ -68,6 +70,7 @@ def ocr_predictor(
 symmetric_pad: bool = True,
 export_as_straight_boxes: bool = False,
 detect_orientation: bool = False,
+straighten_pages: bool = False,
 detect_language: bool = False,
 **kwargs: Any,
 ) -> OCRPredictor:

@@ -80,6 +83,7 @@
 >>> out = model([input_page])

 Args:
+----
 det_arch: name of the detection architecture or the model itself to use
 (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
 reco_arch: name of the recognition architecture or the model itself to use

@@ -95,14 +99,18 @@ def ocr_predictor(
 (potentially rotated) as straight bounding boxes.
 detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
 page. Doing so will slightly deteriorate the overall latency.
+straighten_pages: if True, estimates the page general orientation
+based on the segmentation map median line orientation.
+Then, rotates page before passing it again to the deep learning detection module.
+Doing so will improve performances for documents with page-uniform rotations.
 detect_language: if True, the language prediction will be added to the predictions for each
 page. Doing so will slightly deteriorate the overall latency.
 kwargs: keyword args of `OCRPredictor`

 Returns:
+-------
 OCR predictor
 """
-
 return _predictor(
 det_arch,
 reco_arch,

@@ -113,6 +121,7 @@ def ocr_predictor(
 symmetric_pad=symmetric_pad,
 export_as_straight_boxes=export_as_straight_boxes,
 detect_orientation=detect_orientation,
+straighten_pages=straighten_pages,
 detect_language=detect_language,
 **kwargs,
 )

@@ -129,6 +138,7 @@ def _kie_predictor(
 det_bs: int = 2,
 reco_bs: int = 128,
 detect_orientation: bool = False,
+straighten_pages: bool = False,
 detect_language: bool = False,
 **kwargs,
 ) -> KIEPredictor:

@@ -158,6 +168,7 @@ def _kie_predictor(
 preserve_aspect_ratio=preserve_aspect_ratio,
 symmetric_pad=symmetric_pad,
 detect_orientation=detect_orientation,
+straighten_pages=straighten_pages,
 detect_language=detect_language,
 **kwargs,
 )

@@ -173,6 +184,7 @@ def kie_predictor(
 symmetric_pad: bool = True,
 export_as_straight_boxes: bool = False,
 detect_orientation: bool = False,
+straighten_pages: bool = False,
 detect_language: bool = False,
 **kwargs: Any,
 ) -> KIEPredictor:

@@ -185,6 +197,7 @@ def kie_predictor(
 >>> out = model([input_page])

 Args:
+----
 det_arch: name of the detection architecture or the model itself to use
 (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
 reco_arch: name of the recognition architecture or the model itself to use

@@ -200,14 +213,18 @@ def kie_predictor(
 (potentially rotated) as straight bounding boxes.
 detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
 page. Doing so will slightly deteriorate the overall latency.
+straighten_pages: if True, estimates the page general orientation
+based on the segmentation map median line orientation.
+Then, rotates page before passing it again to the deep learning detection module.
+Doing so will improve performances for documents with page-uniform rotations.
 detect_language: if True, the language prediction will be added to the predictions for each
 page. Doing so will slightly deteriorate the overall latency.
 kwargs: keyword args of `OCRPredictor`

 Returns:
+-------
 KIE predictor
 """
-
 return _kie_predictor(
 det_arch,
 reco_arch,

@@ -218,6 +235,7 @@ def kie_predictor(
 symmetric_pad=symmetric_pad,
 export_as_straight_boxes=export_as_straight_boxes,
 detect_orientation=detect_orientation,
+straighten_pages=straighten_pages,
 detect_language=detect_language,
 **kwargs,
 )
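ocr_predictor and kie_predictor both gain a straighten_pages flag, which _predictor and _kie_predictor now forward to the underlying predictor. A hedged usage sketch (the PDF path is a placeholder):

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    doc = DocumentFile.from_pdf("path/to/your/doc.pdf")  # placeholder path
    predictor = ocr_predictor(
        det_arch="db_resnet50",
        reco_arch="crnn_vgg16_bn",
        pretrained=True,
        straighten_pages=True,  # estimate page orientation and rotate before detection
    )
    result = predictor(doc)
    json_output = result.export()  # nested dict of pages, blocks, lines and words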
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2024, Mindee.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

@@ -20,10 +20,12 @@ def crop_boxes(
 """Crop localization boxes

 Args:
+----
 boxes: ndarray of shape (N, 4) in relative or abs coordinates
 crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in the same coord format that the boxes

 Returns:
+-------
 the cropped boxes
 """
 is_box_rel = boxes.max() <= 1

@@ -52,10 +54,12 @@ def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float,
 the same direction until we meet one of the edges.

 Args:
+----
 line: array of shape (2, 2) of the point supposed to be on one edge, and the shadow tip.
 target_shape: the desired mask shape

 Returns:
+-------
 2D coordinates of the first point once we extended the line (on one of the edges)
 """
 if any(coord == 0 or coord == size for coord, size in zip(line[0], target_shape[::-1])):

@@ -116,15 +120,16 @@
 """Creates a random shadow mask

 Args:
+----
 target_shape: the target shape (H, W)
 min_base_width: the relative minimum shadow base width
 max_tip_width: the relative maximum shadow tip width
 max_tip_height: the relative maximum shadow tip height

 Returns:
+-------
 a numpy ndarray of shape (H, W, 1) with values in the range [0, 1]
 """
-
 # Default base is top
 _params = np.random.rand(6)
 base_width = min_base_width + (1 - min_base_width) * _params[0]

@@ -195,4 +200,4 @@
 mask: np.ndarray = np.zeros((*target_shape, 1), dtype=np.uint8)
 mask = cv2.fillPoly(mask, [final_contour], (255,), lineType=cv2.LINE_AA)[..., 0]

-return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32)
+return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32)  # type: ignore[operator]