python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/master/tensorflow.py

```diff
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, layers
@@ -19,7 +19,7 @@ from .base import _MASTER, _MASTERPostProcessor
 __all__ = ["MASTER", "master"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "master": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -35,7 +35,6 @@ class MASTER(_MASTER, Model):
     Implementation based on the official TF implementation: <https://github.com/jiangxiluning/MASTER-TF>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary, (without EOS, SOS, PAD)
         d_model: d parameter for the transformer decoder
@@ -59,9 +58,9 @@ class MASTER(_MASTER, Model):
         num_layers: int = 3,
         max_length: int = 50,
         dropout: float = 0.2,
-        input_shape:
+        input_shape: tuple[int, int, int] = (32, 128, 3),  # different from the paper
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
 
@@ -88,8 +87,17 @@ class MASTER(_MASTER, Model):
         self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform())
         self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     @tf.function
-    def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) ->
+    def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
         # [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
         # (N, 1, 1, max_length)
         target_pad_mask = tf.cast(tf.math.not_equal(target, self.vocab_size + 2), dtype=tf.uint8)
@@ -109,19 +117,17 @@ class MASTER(_MASTER, Model):
     def compute_loss(
         model_output: tf.Tensor,
         gt: tf.Tensor,
-        seq_len:
+        seq_len: list[int],
     ) -> tf.Tensor:
         """Compute categorical cross-entropy loss for the model.
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             gt: the encoded tensor with gt labels
             model_output: predicted logits of the model
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
@@ -144,15 +150,14 @@ class MASTER(_MASTER, Model):
     def call(
         self,
         x: tf.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         """Call function for training
 
         Args:
-        ----
             x: images
             target: list of str labels
             return_model_output: if True, return logits
@@ -160,7 +165,6 @@ class MASTER(_MASTER, Model):
             **kwargs: keyword arguments passed to the decoder
 
         Returns:
-        -------
             A dictionnary containing eventually loss, logits and predictions.
         """
         # Encode
@@ -171,7 +175,7 @@ class MASTER(_MASTER, Model):
         # add positional encoding to features
         encoded = self.positional_encoding(feature, **kwargs)
 
-        out:
+        out: dict[str, tf.Tensor] = {}
 
         if kwargs.get("training", False) and target is None:
             raise ValueError("Need to provide labels during training")
@@ -209,13 +213,11 @@ class MASTER(_MASTER, Model):
         """Decode function for prediction
 
         Args:
-        ----
             encoded: encoded features
             **kwargs: keyword arguments passed to the decoder
 
         Returns:
-
-            A Tuple of tf.Tensor: predictions, logits
+            A tuple of tf.Tensor: predictions, logits
         """
         b = encoded.shape[0]
 
@@ -247,14 +249,13 @@ class MASTERPostProcessor(_MASTERPostProcessor):
     """Post processor for MASTER architectures
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         # N x L
@@ -295,9 +296,7 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
@@ -312,12 +311,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keywoard arguments passed to the MASTER architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _master("master", pretrained, magc_resnet31, **kwargs)
```
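The `_master` hunk above replaces the old module-level call to `load_pretrained_params(model, url, ...)` with the new instance method `MASTER.from_pretrained`, which simply forwards to `doctr.models.utils.load_pretrained_params`. A minimal usage sketch under that assumption; the vocab string and checkpoint URL below are illustrative placeholders, not assets shipped with this release:

```python
# Minimal sketch, assuming the 0.12 loading API shown in the diff above.
from doctr.models import master

# Building the architecture with a custom vocab changes the size of the output head.
model = master(pretrained=False, vocab="0123456789")

# skip_mismatch=True skips layers whose shapes no longer match (the head here)
# instead of raising, mirroring what `_master` does when the vocabs differ.
model.from_pretrained(
    "https://example.com/master_checkpoint.zip",  # placeholder URL
    skip_mismatch=True,
)
```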
doctr/models/recognition/parseq/__init__.py

```diff
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
     from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
```
doctr/models/recognition/parseq/base.py

```diff
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Tuple
 
 import numpy as np
 
@@ -17,17 +16,15 @@ class _PARSeq:
 
     def build_target(
         self,
-        gts:
-    ) ->
+        gts: list[str],
+    ) -> tuple[np.ndarray, list[int]]:
         """Encode a list of gts sequences into a np array and gives the corresponding*
         sequence lengths.
 
         Args:
-        ----
             gts: list of ground-truth labels
 
         Returns:
-        -------
             A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
         """
         encoded = encode_sequences(
@@ -46,7 +43,6 @@ class _PARSeqPostProcessor(RecognitionPostProcessor):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
```
doctr/models/recognition/parseq/pytorch.py

```diff
@@ -1,12 +1,13 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
+from collections.abc import Callable
 from copy import deepcopy
 from itertools import permutations
-from typing import Any
+from typing import Any
 
 import numpy as np
 import torch
@@ -23,7 +24,7 @@ from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "parseq": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -38,7 +39,6 @@ class CharEmbedding(nn.Module):
    """Implements the character embedding module
 
    Args:
-    ----
        vocab_size: size of the vocabulary
        d_model: dimension of the model
    """
@@ -56,7 +56,6 @@ class PARSeqDecoder(nn.Module):
    """Implements decoder module of the PARSeq model
 
    Args:
-    ----
        d_model: dimension of the model
        num_heads: number of attention heads
        ffd: dimension of the feed forward layer
@@ -77,8 +76,6 @@ class PARSeqDecoder(nn.Module):
         self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
         self.position_feed_forward = PositionwiseFeedForward(d_model, ffd * ffd_ratio, dropout, nn.GELU())
 
-        self.attention_norm = nn.LayerNorm(d_model, eps=1e-5)
-        self.cross_attention_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.query_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.content_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.feed_forward_norm = nn.LayerNorm(d_model, eps=1e-5)
@@ -92,7 +89,7 @@ class PARSeqDecoder(nn.Module):
         target,
         content,
         memory,
-        target_mask:
+        target_mask: torch.Tensor | None = None,
     ):
         query_norm = self.query_norm(target)
         content_norm = self.content_norm(content)
@@ -112,7 +109,6 @@ class PARSeq(_PARSeq, nn.Module):
    Slightly modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
 
    Args:
-    ----
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        embedding_units: number of embedding units
@@ -136,9 +132,9 @@ class PARSeq(_PARSeq, nn.Module):
         dec_num_heads: int = 12,
         dec_ff_dim: int = 384,  # we use it from the original implementation instead of 2048
         dec_ffd_ratio: int = 4,
-        input_shape:
+        input_shape: tuple[int, int, int] = (3, 32, 128),
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -175,6 +171,26 @@ class PARSeq(_PARSeq, nn.Module):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
+        # ref.: https://github.com/mindee/doctr/issues/1911
+        if kwargs.get("ignore_keys") is None:
+            kwargs["ignore_keys"] = []
+
+        kwargs["ignore_keys"].extend([
+            "decoder.attention_norm.weight",
+            "decoder.attention_norm.bias",
+            "decoder.cross_attention_norm.weight",
+            "decoder.cross_attention_norm.bias",
+        ])
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def generate_permutations(self, seqlen: torch.Tensor) -> torch.Tensor:
         # Generates permutations of the target sequence.
         # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -217,7 +233,7 @@ class PARSeq(_PARSeq, nn.Module):
         combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
         return combined
 
-    def generate_permutations_attention_masks(self, permutation: torch.Tensor) ->
+    def generate_permutations_attention_masks(self, permutation: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         # Generate source and target mask for the decoder attention.
         sz = permutation.shape[0]
         mask = torch.ones((sz, sz), device=permutation.device)
@@ -236,8 +252,8 @@ class PARSeq(_PARSeq, nn.Module):
         self,
         target: torch.Tensor,
         memory: torch.Tensor,
-        target_mask:
-        target_query:
+        target_mask: torch.Tensor | None = None,
+        target_query: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Add positional information to the target sequence and pass it through the decoder."""
         batch_size, sequence_length = target.shape
@@ -250,7 +266,7 @@ class PARSeq(_PARSeq, nn.Module):
         target_query = self.dropout(target_query)
         return self.decoder(target_query, content, memory, target_mask)
 
-    def decode_autoregressive(self, features: torch.Tensor, max_len:
+    def decode_autoregressive(self, features: torch.Tensor, max_len: int | None = None) -> torch.Tensor:
         """Generate predictions for the given features."""
         max_length = max_len if max_len is not None else self.max_length
         max_length = min(max_length, self.max_length) + 1
@@ -283,7 +299,7 @@ class PARSeq(_PARSeq, nn.Module):
 
             # Stop decoding if all sequences have reached the EOS token
             # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
+            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():  # type: ignore[attr-defined]
                 break
 
         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)
@@ -298,7 +314,7 @@ class PARSeq(_PARSeq, nn.Module):
 
             # Create padding mask for refined target input maskes all behind EOS token as False
             # (N, 1, 1, max_length)
-            target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+            target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)  # type: ignore[attr-defined]
             mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
             logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
 
@@ -307,10 +323,10 @@ class PARSeq(_PARSeq, nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]  # (batch_size, patches_seqlen, d_model)
         # remove cls token
         features = features[:, 1:, :]
@@ -337,7 +353,7 @@ class PARSeq(_PARSeq, nn.Module):
             ).unsqueeze(1).unsqueeze(1)  # (N, 1, 1, seq_len)
 
             loss = torch.tensor(0.0, device=features.device)
-            loss_numel:
+            loss_numel: int | float = 0
             n = (gt_out != self.vocab_size + 2).sum().item()
             for i, perm in enumerate(tgt_perms):
                 _, target_mask = self.generate_permutations_attention_masks(perm)  # (seq_len, seq_len)
@@ -351,7 +367,7 @@ class PARSeq(_PARSeq, nn.Module):
                 # remove the [EOS] tokens for the succeeding perms
                 if i == 1:
                     gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
-                    n = (gt_out != self.vocab_size + 2).sum().item()
+                    n = (gt_out != self.vocab_size + 2).sum().item()  # type: ignore[attr-defined]
 
             loss /= loss_numel
 
@@ -365,7 +381,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         logits = _bf16_to_float32(logits)
 
-        out:
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -374,8 +390,13 @@ class PARSeq(_PARSeq, nn.Module):
             out["out_map"] = logits
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(logits)
+
             # Post-process boxes
-            out["preds"] =
+            out["preds"] = _postprocess(logits)
 
         if target is not None:
             out["loss"] = loss
@@ -387,14 +408,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
    """Post processor for PARSeq architecture
 
    Args:
-    ----
        vocab: string containing the ordered sequence of supported characters
    """
 
    def __call__(
        self,
        logits: torch.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
        # compute pred with argmax for attention models
        out_idxs = logits.argmax(-1)
        preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
@@ -417,7 +437,7 @@ def _parseq(
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
-    ignore_keys:
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> PARSeq:
     # Patch the config
@@ -446,7 +466,7 @@ def _parseq(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -462,12 +482,10 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the PARSeq architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _parseq(
```
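The PyTorch PARSeq hunks remove the decoder's unused `attention_norm` and `cross_attention_norm` layers, and the new `from_pretrained` override always appends those parameter names to `ignore_keys` so that checkpoints trained with docTR versions before 0.11.1 still load (ref. issue 1911 cited in the diff). A rough usage sketch under that assumption; the checkpoint path is a placeholder:

```python
# Rough sketch, assuming the PyTorch PARSeq API shown in the diff above.
from doctr.models import parseq

model = parseq(pretrained=False)

# The override silently ignores the four removed LayerNorm keys
# (decoder.attention_norm.* and decoder.cross_attention_norm.*),
# so a pre-0.11.1 checkpoint no longer fails on unexpected keys.
model.from_pretrained("parseq_trained_with_doctr_0.10.pt")  # placeholder path
```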
doctr/models/recognition/parseq/tensorflow.py

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 import math
 from copy import deepcopy
 from itertools import permutations
-from typing import Any
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -21,7 +21,7 @@ from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "parseq": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -36,7 +36,7 @@ class CharEmbedding(layers.Layer):
    """Implements the character embedding module
 
    Args:
-
+    -
        vocab_size: size of the vocabulary
        d_model: dimension of the model
    """
@@ -54,7 +54,6 @@ class PARSeqDecoder(layers.Layer):
    """Implements decoder module of the PARSeq model
 
    Args:
-    ----
        d_model: dimension of the model
        num_heads: number of attention heads
        ffd: dimension of the feed forward layer
@@ -77,8 +76,6 @@ class PARSeqDecoder(layers.Layer):
             d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
         )
 
-        self.attention_norm = layers.LayerNormalization(epsilon=1e-5)
-        self.cross_attention_norm = layers.LayerNormalization(epsilon=1e-5)
         self.query_norm = layers.LayerNormalization(epsilon=1e-5)
         self.content_norm = layers.LayerNormalization(epsilon=1e-5)
         self.feed_forward_norm = layers.LayerNormalization(epsilon=1e-5)
@@ -115,7 +112,6 @@ class PARSeq(_PARSeq, Model):
    Modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
 
    Args:
-    ----
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        embedding_units: number of embedding units
@@ -129,7 +125,7 @@ class PARSeq(_PARSeq, Model):
        cfg: dictionary containing information about the model
    """
 
-    _children_names:
+    _children_names: list[str] = ["feat_extractor", "postprocessor"]
 
    def __init__(
        self,
@@ -141,9 +137,9 @@ class PARSeq(_PARSeq, Model):
         dec_num_heads: int = 12,
         dec_ff_dim: int = 384,  # we use it from the original implementation instead of 2048
         dec_ffd_ratio: int = 4,
-        input_shape:
+        input_shape: tuple[int, int, int] = (32, 128, 3),
         exportable: bool = False,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -167,6 +163,18 @@ class PARSeq(_PARSeq, Model):
 
         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
+        # ref.: https://github.com/mindee/doctr/issues/1911
+        kwargs["skip_mismatch"] = True
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
         # Generates permutations of the target sequence.
         # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -213,7 +221,7 @@ class PARSeq(_PARSeq, Model):
         )
         return combined
 
-    def generate_permutations_attention_masks(self, permutation: tf.Tensor) ->
+    def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
         # Generate source and target mask for the decoder attention.
         sz = permutation.shape[0]
         mask = tf.ones((sz, sz), dtype=tf.float32)
@@ -236,8 +244,8 @@ class PARSeq(_PARSeq, Model):
         self,
         target: tf.Tensor,
         memory: tf.Tensor,
-        target_mask:
-        target_query:
+        target_mask: tf.Tensor | None = None,
+        target_query: tf.Tensor | None = None,
         **kwargs: Any,
     ) -> tf.Tensor:
         batch_size, sequence_length = target.shape
@@ -250,8 +258,7 @@ class PARSeq(_PARSeq, Model):
         target_query = self.dropout(target_query, **kwargs)
         return self.decoder(target_query, content, memory, target_mask, **kwargs)
 
-
-    def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
+    def decode_autoregressive(self, features: tf.Tensor, max_len: int | None = None, **kwargs) -> tf.Tensor:
         """Generate predictions for the given features."""
         max_length = max_len if max_len is not None else self.max_length
         max_length = min(max_length, self.max_length) + 1
@@ -318,11 +325,11 @@ class PARSeq(_PARSeq, Model):
     def call(
         self,
         x: tf.Tensor,
-        target:
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) ->
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
         # remove cls token
         features = features[:, 1:, :]
@@ -393,7 +400,7 @@ class PARSeq(_PARSeq, Model):
 
         logits = _bf16_to_float32(logits)
 
-        out:
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -415,14 +422,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
    """Post processor for PARSeq architecture
 
    Args:
-    ----
        vocab: string containing the ordered sequence of supported characters
    """
 
    def __call__(
        self,
        logits: tf.Tensor,
-    ) ->
+    ) -> list[tuple[str, float]]:
        # compute pred with argmax for attention models
        out_idxs = tf.math.argmax(logits, axis=2)
        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
@@ -448,7 +454,7 @@ def _parseq(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    input_shape:
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> PARSeq:
     # Patch the config
@@ -478,9 +484,7 @@ def _parseq(
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
@@ -496,12 +500,10 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the PARSeq architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _parseq(
```
doctr/models/recognition/predictor/__init__.py

```diff
@@ -1,6 +1,6 @@
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 
-if
-    from .
-
-    from .
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
```
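Both `__init__.py` hunks reverse the backend priority: when both frameworks are installed, the PyTorch implementations are now imported ahead of the TensorFlow ones. A small sketch (not part of the package) for checking which backend an environment will resolve to, using the same helpers the diffs import:

```python
# Sketch only: doctr.file_utils provides these helpers, as used in the diffs above.
from doctr.file_utils import is_tf_available, is_torch_available

backend = "pytorch" if is_torch_available() else "tensorflow" if is_tf_available() else "none"
print(f"docTR will resolve model modules to the {backend} implementations")
```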