python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +1 -5
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1100 -54
- doctr/file_utils.py +2 -92
- doctr/io/elements.py +37 -3
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +4 -4
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -4
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +11 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vip/__init__.py +1 -0
- doctr/models/classification/vip/layers/__init__.py +1 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +12 -3
- doctr/models/classification/zoo.py +7 -8
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +7 -16
- doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +6 -17
- doctr/models/detection/fast/pytorch.py +17 -8
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +5 -15
- doctr/models/detection/linknet/pytorch.py +12 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +9 -22
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +52 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +16 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +15 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +26 -8
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +100 -47
- doctr/models/recognition/predictor/pytorch.py +4 -5
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +13 -4
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +1 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +13 -4
- doctr/models/recognition/zoo.py +13 -8
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +29 -19
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -433
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -397
- doctr/models/classification/textnet/tensorflow.py +0 -266
- doctr/models/classification/vgg/tensorflow.py +0 -116
- doctr/models/classification/vit/tensorflow.py +0 -192
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
- doctr/models/detection/fast/tensorflow.py +0 -419
- doctr/models/detection/linknet/tensorflow.py +0 -369
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -308
- doctr/models/recognition/master/tensorflow.py +0 -313
- doctr/models/recognition/parseq/tensorflow.py +0 -508
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -416
- doctr/models/recognition/vitstr/tensorflow.py +0 -278
- doctr/models/utils/tensorflow.py +0 -182
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.11.0.dist-info/RECORD +0 -173
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
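Reading the listing above: every `tensorflow.py` module drops to zero lines while its `pytorch.py` sibling is kept and updated, so the 1.0.0 wheel appears to drop the TensorFlow backend entirely and ship PyTorch-only. The high-level entry point should be unaffected; a minimal sketch, assuming the public `DocumentFile` and `ocr_predictor` APIs carry over from 0.11.0 unchanged:

# Minimal end-to-end usage sketch against the PyTorch-only 1.0.0 wheel.
# Assumes the public DocumentFile / ocr_predictor API is unchanged from 0.11.0.
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

doc = DocumentFile.from_pdf("sample.pdf")   # hypothetical input path; or DocumentFile.from_images([...])
predictor = ocr_predictor(pretrained=True)  # detection + recognition pipeline
result = predictor(doc)
print(result.render())                      # plain-text export of the parsed pages

The two largest deleted files are reconstructed below.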
--- a/doctr/models/preprocessor/tensorflow.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-import math
-from typing import Any
-
-import numpy as np
-import tensorflow as tf
-
-from doctr.transforms import Normalize, Resize
-from doctr.utils.multithreading import multithread_exec
-from doctr.utils.repr import NestedObject
-
-__all__ = ["PreProcessor"]
-
-
-class PreProcessor(NestedObject):
-    """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
-
-    Args:
-        output_size: expected size of each page in format (H, W)
-        batch_size: the size of page batches
-        mean: mean value of the training distribution by channel
-        std: standard deviation of the training distribution by channel
-        **kwargs: additional arguments for the resizing operation
-    """
-
-    _children_names: list[str] = ["resize", "normalize"]
-
-    def __init__(
-        self,
-        output_size: tuple[int, int],
-        batch_size: int,
-        mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
-        std: tuple[float, float, float] = (1.0, 1.0, 1.0),
-        **kwargs: Any,
-    ) -> None:
-        self.batch_size = batch_size
-        self.resize = Resize(output_size, **kwargs)
-        # Perform the division by 255 at the same time
-        self.normalize = Normalize(mean, std)
-        self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []
-
-    def batch_inputs(self, samples: list[tf.Tensor]) -> list[tf.Tensor]:
-        """Gather samples into batches for inference purposes
-
-        Args:
-            samples: list of samples (tf.Tensor)
-
-        Returns:
-            list of batched samples
-        """
-        num_batches = int(math.ceil(len(samples) / self.batch_size))
-        batches = [
-            tf.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0)
-            for idx in range(int(num_batches))
-        ]
-
-        return batches
-
-    def sample_transforms(self, x: np.ndarray | tf.Tensor) -> tf.Tensor:
-        if x.ndim != 3:
-            raise AssertionError("expected list of 3D Tensors")
-        if isinstance(x, np.ndarray):
-            if x.dtype not in (np.uint8, np.float32):
-                raise TypeError("unsupported data type for numpy.ndarray")
-            x = tf.convert_to_tensor(x)
-        elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
-            raise TypeError("unsupported data type for torch.Tensor")
-        # Data type & 255 division
-        if x.dtype == tf.uint8:
-            x = tf.image.convert_image_dtype(x, dtype=tf.float32)
-        # Resizing
-        x = self.resize(x)
-
-        return x
-
-    def __call__(self, x: tf.Tensor | np.ndarray | list[tf.Tensor | np.ndarray]) -> list[tf.Tensor]:
-        """Prepare document data for model forwarding
-
-        Args:
-            x: list of images (np.array) or tensors (already resized and batched)
-
-        Returns:
-            list of page batches
-        """
-        # Input type check
-        if isinstance(x, (np.ndarray, tf.Tensor)):
-            if x.ndim != 4:
-                raise AssertionError("expected 4D Tensor")
-            if isinstance(x, np.ndarray):
-                if x.dtype not in (np.uint8, np.float32):
-                    raise TypeError("unsupported data type for numpy.ndarray")
-                x = tf.convert_to_tensor(x)
-            elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
-                raise TypeError("unsupported data type for torch.Tensor")
-
-            # Data type & 255 division
-            if x.dtype == tf.uint8:
-                x = tf.image.convert_image_dtype(x, dtype=tf.float32)
-            # Resizing
-            if (x.shape[1], x.shape[2]) != self.resize.output_size:
-                x = tf.image.resize(
-                    x, self.resize.output_size, method=self.resize.method, antialias=self.resize.antialias
-                )
-
-            batches = [x]
-
-        elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
-            # Sample transform (to tensor, resize)
-            samples = list(multithread_exec(self.sample_transforms, x, threads=1 if self._runs_on_cuda else None))
-            # Batching
-            batches = self.batch_inputs(samples)
-        else:
-            raise TypeError(f"invalid input type: {type(x)}")
-
-        # Batch transforms (normalize)
-        batches = list(multithread_exec(self.normalize, batches, threads=1 if self._runs_on_cuda else None))
-
-        return batches
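The PyTorch counterpart of this preprocessor survives the release (`doctr/models/preprocessor/pytorch.py` is modified above, not deleted). A minimal sketch of the equivalent call, assuming the constructor mirrors the deleted TF version (`output_size`, `batch_size`, plus `mean`/`std` defaults) and that the torch variant emits channels-first batches:

# Sketch: batch + resize + normalize raw pages with the surviving PyTorch
# PreProcessor. Signature assumed to mirror the deleted TF version above.
import numpy as np
from doctr.models.preprocessor import PreProcessor

processor = PreProcessor(output_size=(1024, 1024), batch_size=2)

# Three uint8 pages of different sizes -> resized, batched, normalized tensors
pages = [
    np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
    for h, w in [(900, 600), (1100, 800), (700, 500)]
]
batches = processor(pages)
print(len(batches))      # 2 batches (3 samples at batch_size=2)
print(batches[0].shape)  # expected torch.Size([2, 3, 1024, 1024]) (channels-first)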
--- a/doctr/models/recognition/crnn/tensorflow.py
+++ /dev/null
@@ -1,308 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-from copy import deepcopy
-from typing import Any
-
-import tensorflow as tf
-from tensorflow.keras import layers
-from tensorflow.keras.models import Model, Sequential
-
-from doctr.datasets import VOCABS
-
-from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
-from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
-from ..core import RecognitionModel, RecognitionPostProcessor
-
-__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
-
-default_cfgs: dict[str, dict[str, Any]] = {
-    "crnn_vgg16_bn": {
-        "mean": (0.694, 0.695, 0.693),
-        "std": (0.299, 0.296, 0.301),
-        "input_shape": (32, 128, 3),
-        "vocab": VOCABS["legacy_french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0",
-    },
-    "crnn_mobilenet_v3_small": {
-        "mean": (0.694, 0.695, 0.693),
-        "std": (0.299, 0.296, 0.301),
-        "input_shape": (32, 128, 3),
-        "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0",
-    },
-    "crnn_mobilenet_v3_large": {
-        "mean": (0.694, 0.695, 0.693),
-        "std": (0.299, 0.296, 0.301),
-        "input_shape": (32, 128, 3),
-        "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0",
-    },
-}
-
-
-class CTCPostProcessor(RecognitionPostProcessor):
-    """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
-
-    Args:
-        vocab: string containing the ordered sequence of supported characters
-        ignore_case: if True, ignore case of letters
-        ignore_accents: if True, ignore accents of letters
-    """
-
-    def __call__(
-        self,
-        logits: tf.Tensor,
-        beam_width: int = 1,
-        top_paths: int = 1,
-    ) -> list[tuple[str, float]] | list[tuple[list[str] | list[float]]]:
-        """Performs decoding of raw output with CTC and decoding of CTC predictions
-        with label_to_idx mapping dictionnary
-
-        Args:
-            logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1
-            beam_width: An int scalar >= 0 (beam search beam width).
-            top_paths: An int scalar >= 0, <= beam_width (controls output size).
-
-        Returns:
-            A list of decoded words of length BATCH_SIZE
-
-
-        """
-        # Decode CTC
-        _decoded, _log_prob = tf.nn.ctc_beam_search_decoder(
-            tf.transpose(logits, perm=[1, 0, 2]),
-            tf.fill(tf.shape(logits)[:1], tf.shape(logits)[1]),
-            beam_width=beam_width,
-            top_paths=top_paths,
-        )
-
-        _decoded = tf.sparse.concat(
-            1,
-            [tf.sparse.expand_dims(dec, axis=1) for dec in _decoded],
-            expand_nonconcat_dims=True,
-        )  # dim : batchsize x beamwidth x actual_max_len_predictions
-        out_idxs = tf.sparse.to_dense(_decoded, default_value=len(self.vocab))
-
-        # Map it to characters
-        _decoded_strings_pred = tf.strings.reduce_join(
-            inputs=tf.nn.embedding_lookup(tf.constant(self._embedding, dtype=tf.string), out_idxs),
-            axis=-1,
-        )
-        _decoded_strings_pred = tf.strings.split(_decoded_strings_pred, "<eos>")
-        decoded_strings_pred = tf.sparse.to_dense(_decoded_strings_pred.to_sparse(), default_value="not valid")[
-            :, :, 0
-        ]  # dim : batch_size x beam_width
-
-        if top_paths == 1:
-            probs = tf.math.exp(tf.squeeze(_log_prob, axis=1))  # dim : batchsize
-            decoded_strings_pred = tf.squeeze(decoded_strings_pred, axis=1)
-            word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
-        else:
-            probs = tf.math.exp(_log_prob)  # dim : batchsize x beamwidth
-            word_values = [[word.decode() for word in words] for words in decoded_strings_pred.numpy().tolist()]
-        return list(zip(word_values, probs.numpy().tolist()))
-
-
-class CRNN(RecognitionModel, Model):
-    """Implements a CRNN architecture as described in `"An End-to-End Trainable Neural Network for Image-based
-    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
-
-    Args:
-        feature_extractor: the backbone serving as feature extractor
-        vocab: vocabulary used for encoding
-        rnn_units: number of units in the LSTM layers
-        exportable: onnx exportable returns only logits
-        beam_width: beam width for beam search decoding
-        top_paths: number of top paths for beam search decoding
-        cfg: configuration dictionary
-    """
-
-    _children_names: list[str] = ["feat_extractor", "decoder", "postprocessor"]
-
-    def __init__(
-        self,
-        feature_extractor: Model,
-        vocab: str,
-        rnn_units: int = 128,
-        exportable: bool = False,
-        beam_width: int = 1,
-        top_paths: int = 1,
-        cfg: dict[str, Any] | None = None,
-    ) -> None:
-        # Initialize kernels
-        h, w, c = feature_extractor.output_shape[1:]
-
-        super().__init__()
-        self.vocab = vocab
-        self.max_length = w
-        self.cfg = cfg
-        self.exportable = exportable
-        self.feat_extractor = feature_extractor
-
-        self.decoder = Sequential([
-            layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
-            layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
-            layers.Dense(units=len(vocab) + 1),
-        ])
-        self.decoder.build(input_shape=(None, w, h * c))
-
-        self.postprocessor = CTCPostProcessor(vocab=vocab)
-
-        self.beam_width = beam_width
-        self.top_paths = top_paths
-
-    def compute_loss(
-        self,
-        model_output: tf.Tensor,
-        target: list[str],
-    ) -> tf.Tensor:
-        """Compute CTC loss for the model.
-
-        Args:
-            model_output: predicted logits of the model
-            target: lengths of each gt word inside the batch
-
-        Returns:
-            The loss of the model on the batch
-        """
-        gt, seq_len = self.build_target(target)
-        batch_len = model_output.shape[0]
-        input_length = tf.fill((batch_len,), model_output.shape[1])
-        ctc_loss = tf.nn.ctc_loss(
-            gt, model_output, seq_len, input_length, logits_time_major=False, blank_index=len(self.vocab)
-        )
-        return ctc_loss
-
-    def call(
-        self,
-        x: tf.Tensor,
-        target: list[str] | None = None,
-        return_model_output: bool = False,
-        return_preds: bool = False,
-        beam_width: int = 1,
-        top_paths: int = 1,
-        **kwargs: Any,
-    ) -> dict[str, Any]:
-        if kwargs.get("training", False) and target is None:
-            raise ValueError("Need to provide labels during training")
-
-        features = self.feat_extractor(x, **kwargs)
-        # B x H x W x C --> B x W x H x C
-        transposed_feat = tf.transpose(features, perm=[0, 2, 1, 3])
-        w, h, c = transposed_feat.get_shape().as_list()[1:]
-        # B x W x H x C --> B x W x H * C
-        features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
-        logits = _bf16_to_float32(self.decoder(features_seq, **kwargs))
-
-        out: dict[str, tf.Tensor] = {}
-        if self.exportable:
-            out["logits"] = logits
-            return out
-
-        if return_model_output:
-            out["out_map"] = logits
-
-        if target is None or return_preds:
-            # Post-process boxes
-            out["preds"] = self.postprocessor(logits, beam_width=beam_width, top_paths=top_paths)
-
-        if target is not None:
-            out["loss"] = self.compute_loss(logits, target)
-
-        return out
-
-
-def _crnn(
-    arch: str,
-    pretrained: bool,
-    backbone_fn,
-    pretrained_backbone: bool = True,
-    input_shape: tuple[int, int, int] | None = None,
-    **kwargs: Any,
-) -> CRNN:
-    pretrained_backbone = pretrained_backbone and not pretrained
-
-    kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
-
-    _cfg = deepcopy(default_cfgs[arch])
-    _cfg["vocab"] = kwargs["vocab"]
-    _cfg["input_shape"] = input_shape or default_cfgs[arch]["input_shape"]
-
-    feat_extractor = backbone_fn(
-        input_shape=_cfg["input_shape"],
-        include_top=False,
-        pretrained=pretrained_backbone,
-    )
-
-    # Build the model
-    model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
-    _build_model(model)
-    # Load pretrained parameters
-    if pretrained:
-        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
-
-    return model
-
-
-def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
-    """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
-    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
-
-    >>> import tensorflow as tf
-    >>> from doctr.models import crnn_vgg16_bn
-    >>> model = crnn_vgg16_bn(pretrained=True)
-    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-        **kwargs: keyword arguments of the CRNN architecture
-
-    Returns:
-        text recognition architecture
-    """
-    return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs)
-
-
-def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
-    """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
-    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
-
-    >>> import tensorflow as tf
-    >>> from doctr.models import crnn_mobilenet_v3_small
-    >>> model = crnn_mobilenet_v3_small(pretrained=True)
-    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-        **kwargs: keyword arguments of the CRNN architecture
-
-    Returns:
-        text recognition architecture
-    """
-    return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs)
-
-
-def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
-    """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
-    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
-
-    >>> import tensorflow as tf
-    >>> from doctr.models import crnn_mobilenet_v3_large
-    >>> model = crnn_mobilenet_v3_large(pretrained=True)
-    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-        **kwargs: keyword arguments of the CRNN architecture
-
-    Returns:
-        text recognition architecture
-    """
-    return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs)