python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +1 -5
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1100 -54
- doctr/file_utils.py +2 -92
- doctr/io/elements.py +37 -3
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +4 -4
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -4
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +11 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vip/__init__.py +1 -0
- doctr/models/classification/vip/layers/__init__.py +1 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +12 -3
- doctr/models/classification/zoo.py +7 -8
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +7 -16
- doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +6 -17
- doctr/models/detection/fast/pytorch.py +17 -8
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +5 -15
- doctr/models/detection/linknet/pytorch.py +12 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +9 -22
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +52 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +16 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +15 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +26 -8
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +100 -47
- doctr/models/recognition/predictor/pytorch.py +4 -5
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +13 -4
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +1 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +13 -4
- doctr/models/recognition/zoo.py +13 -8
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +29 -19
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -433
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -397
- doctr/models/classification/textnet/tensorflow.py +0 -266
- doctr/models/classification/vgg/tensorflow.py +0 -116
- doctr/models/classification/vit/tensorflow.py +0 -192
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
- doctr/models/detection/fast/tensorflow.py +0 -419
- doctr/models/detection/linknet/tensorflow.py +0 -369
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -308
- doctr/models/recognition/master/tensorflow.py +0 -313
- doctr/models/recognition/parseq/tensorflow.py +0 -508
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -416
- doctr/models/recognition/vitstr/tensorflow.py +0 -278
- doctr/models/utils/tensorflow.py +0 -182
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.11.0.dist-info/RECORD +0 -173
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
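Every TensorFlow module in the listing above (the `*/tensorflow.py` files, plus the TF-only loader and transforms) is removed in 1.0.0, leaving the PyTorch implementations as the sole backend. For orientation, here is a minimal sketch of the high-level API those remaining modules serve; it assumes the public `DocumentFile` and `ocr_predictor` entry points keep their 0.11.0 signatures, and "sample.jpg" is a placeholder path, not a file shipped with the package:

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# "sample.jpg" is a hypothetical local image; use from_pdf() for PDF inputs
pages = DocumentFile.from_images("sample.jpg")
# end-to-end detection + recognition predictor, backed exclusively by PyTorch in 1.0.0
predictor = ocr_predictor(pretrained=True)
result = predictor(pages)
print(result.render())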
doctr/models/recognition/sar/tensorflow.py (removed)
@@ -1,416 +0,0 @@
# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any

import tensorflow as tf
from tensorflow.keras import Model, Sequential, layers

from doctr.datasets import VOCABS
from doctr.utils.repr import NestedObject

from ...classification import resnet31
from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["SAR", "sar_resnet31"]

default_cfgs: dict[str, dict[str, Any]] = {
    "sar_resnet31": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (32, 128, 3),
        "vocab": VOCABS["french"],
        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0",
    },
}


class SAREncoder(layers.Layer, NestedObject):
    """Implements encoder module of the SAR model

    Args:
        rnn_units: number of hidden rnn units
        dropout_prob: dropout probability
    """

    def __init__(self, rnn_units: int, dropout_prob: float = 0.0) -> None:
        super().__init__()
        self.rnn = Sequential([
            layers.LSTM(units=rnn_units, return_sequences=True, recurrent_dropout=dropout_prob),
            layers.LSTM(units=rnn_units, return_sequences=False, recurrent_dropout=dropout_prob),
        ])

    def call(
        self,
        x: tf.Tensor,
        **kwargs: Any,
    ) -> tf.Tensor:
        # (N, C)
        return self.rnn(x, **kwargs)


class AttentionModule(layers.Layer, NestedObject):
    """Implements attention module of the SAR model

    Args:
        attention_units: number of hidden attention units

    """

    def __init__(self, attention_units: int) -> None:
        super().__init__()
        self.hidden_state_projector = layers.Conv2D(
            attention_units,
            1,
            strides=1,
            use_bias=False,
            padding="same",
            kernel_initializer="he_normal",
        )
        self.features_projector = layers.Conv2D(
            attention_units,
            3,
            strides=1,
            use_bias=True,
            padding="same",
            kernel_initializer="he_normal",
        )
        self.attention_projector = layers.Conv2D(
            1,
            1,
            strides=1,
            use_bias=False,
            padding="same",
            kernel_initializer="he_normal",
        )
        self.flatten = layers.Flatten()

    def call(
        self,
        features: tf.Tensor,
        hidden_state: tf.Tensor,
        **kwargs: Any,
    ) -> tf.Tensor:
        [H, W] = features.get_shape().as_list()[1:3]
        # shape (N, H, W, vgg_units) -> (N, H, W, attention_units)
        features_projection = self.features_projector(features, **kwargs)
        # shape (N, 1, 1, rnn_units) -> (N, 1, 1, attention_units)
        hidden_state = tf.expand_dims(tf.expand_dims(hidden_state, axis=1), axis=1)
        hidden_state_projection = self.hidden_state_projector(hidden_state, **kwargs)
        projection = tf.math.tanh(hidden_state_projection + features_projection)
        # shape (N, H, W, attention_units) -> (N, H, W, 1)
        attention = self.attention_projector(projection, **kwargs)
        # shape (N, H, W, 1) -> (N, H * W)
        attention = self.flatten(attention)
        attention = tf.nn.softmax(attention)
        # shape (N, H * W) -> (N, H, W, 1)
        attention_map = tf.reshape(attention, [-1, H, W, 1])
        glimpse = tf.math.multiply(features, attention_map)
        # shape (N, H * W) -> (N, C)
        return tf.reduce_sum(glimpse, axis=[1, 2])


class SARDecoder(layers.Layer, NestedObject):
    """Implements decoder module of the SAR model

    Args:
        rnn_units: number of hidden units in recurrent cells
        max_length: maximum length of a sequence
        vocab_size: number of classes in the model alphabet
        embedding_units: number of hidden embedding units
        attention_units: number of hidden attention units
        num_decoder_cells: number of LSTMCell layers to stack
        dropout_prob: dropout probability

    """

    def __init__(
        self,
        rnn_units: int,
        max_length: int,
        vocab_size: int,
        embedding_units: int,
        attention_units: int,
        num_decoder_cells: int = 2,
        dropout_prob: float = 0.0,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.max_length = max_length

        self.embed = layers.Dense(embedding_units, use_bias=False)
        self.embed_tgt = layers.Embedding(embedding_units, self.vocab_size + 1)

        self.lstm_cells = layers.StackedRNNCells([
            layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_cells)
        ])
        self.attention_module = AttentionModule(attention_units)
        self.output_dense = layers.Dense(self.vocab_size + 1, use_bias=True)
        self.dropout = layers.Dropout(dropout_prob)

    def call(
        self,
        features: tf.Tensor,
        holistic: tf.Tensor,
        gt: tf.Tensor | None = None,
        **kwargs: Any,
    ) -> tf.Tensor:
        if gt is not None:
            gt_embedding = self.embed_tgt(gt, **kwargs)

        logits_list: list[tf.Tensor] = []

        for t in range(self.max_length + 1):  # 32
            if t == 0:
                # step to init the first states of the LSTMCell
                states = self.lstm_cells.get_initial_state(
                    inputs=None, batch_size=features.shape[0], dtype=features.dtype
                )
                prev_symbol = holistic
            elif t == 1:
                # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
                # (N, vocab_size + 1) --> (N, embedding_units)
                prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1], dtype=features.dtype)
                prev_symbol = self.embed(prev_symbol, **kwargs)
            else:
                if gt is not None and kwargs.get("training", False):
                    # (N, embedding_units) -2 because of <bos> and <eos> (same)
                    prev_symbol = self.embed(gt_embedding[:, t - 2], **kwargs)
                else:
                    # -1 to start at timestep where prev_symbol was initialized
                    index = tf.argmax(logits_list[t - 1], axis=-1)
                    # update prev_symbol with ones at the index of the previous logit vector
                    prev_symbol = self.embed(self.embed_tgt(index, **kwargs), **kwargs)

            # (N, C), (N, C) take the last hidden state and cell state from current timestep
            _, states = self.lstm_cells(prev_symbol, states, **kwargs)
            # states = (hidden_state, cell_state)
            hidden_state = states[0][0]
            # (N, H, W, C), (N, C) --> (N, C)
            glimpse = self.attention_module(features, hidden_state, **kwargs)
            # (N, C), (N, C) --> (N, 2 * C)
            logits = tf.concat([hidden_state, glimpse], axis=1)
            logits = self.dropout(logits, **kwargs)
            # (N, vocab_size + 1)
            logits_list.append(self.output_dense(logits, **kwargs))

        # (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1)
        return tf.transpose(tf.stack(logits_list[1:]), (1, 0, 2))


class SAR(Model, RecognitionModel):
    """Implements a SAR architecture as described in `"Show, Attend and Read: A Simple and Strong Baseline for
    Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

    Args:
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        rnn_units: number of hidden units in both encoder and decoder LSTM
        embedding_units: number of embedding units
        attention_units: number of hidden units in attention module
        max_length: maximum word length handled by the model
        num_decoder_cells: number of LSTMCell layers to stack
        dropout_prob: dropout probability for the encoder and decoder
        exportable: onnx exportable returns only logits
        cfg: dictionary containing information about the model
    """

    _children_names: list[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]

    def __init__(
        self,
        feature_extractor,
        vocab: str,
        rnn_units: int = 512,
        embedding_units: int = 512,
        attention_units: int = 512,
        max_length: int = 30,
        num_decoder_cells: int = 2,
        dropout_prob: float = 0.0,
        exportable: bool = False,
        cfg: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.vocab = vocab
        self.exportable = exportable
        self.cfg = cfg
        self.max_length = max_length + 1  # Add 1 timestep for EOS after the longest word

        self.feat_extractor = feature_extractor

        self.encoder = SAREncoder(rnn_units, dropout_prob)
        self.decoder = SARDecoder(
            rnn_units,
            self.max_length,
            len(vocab),
            embedding_units,
            attention_units,
            num_decoder_cells,
            dropout_prob,
        )

        self.postprocessor = SARPostProcessor(vocab=vocab)

    @staticmethod
    def compute_loss(
        model_output: tf.Tensor,
        gt: tf.Tensor,
        seq_len: tf.Tensor,
    ) -> tf.Tensor:
        """Compute categorical cross-entropy loss for the model.
        Sequences are masked after the EOS character.

        Args:
            gt: the encoded tensor with gt labels
            model_output: predicted logits of the model
            seq_len: lengths of each gt word inside the batch

        Returns:
            The loss of the model on the batch
        """
        # Input length : number of timesteps
        input_len = tf.shape(model_output)[1]
        # Add one for additional <eos> token
        seq_len = seq_len + 1
        # One-hot gt labels
        oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
        # Compute loss
        cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt, model_output)
        # Compute mask
        mask_values = tf.zeros_like(cce)
        mask_2d = tf.sequence_mask(seq_len, input_len)
        masked_loss = tf.where(mask_2d, cce, mask_values)
        ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))
        return tf.expand_dims(ce_loss, axis=1)

    def call(
        self,
        x: tf.Tensor,
        target: list[str] | None = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        features = self.feat_extractor(x, **kwargs)
        # vertical max pooling --> (N, C, W)
        pooled_features = tf.reduce_max(features, axis=1)
        # holistic (N, C)
        encoded = self.encoder(pooled_features, **kwargs)

        if target is not None:
            gt, seq_len = self.build_target(target)
            seq_len = tf.cast(seq_len, tf.int32)

        if kwargs.get("training", False) and target is None:
            raise ValueError("Need to provide labels during training for teacher forcing")

        decoded_features = _bf16_to_float32(
            self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
        )

        out: dict[str, tf.Tensor] = {}
        if self.exportable:
            out["logits"] = decoded_features
            return out

        if return_model_output:
            out["out_map"] = decoded_features

        if target is None or return_preds:
            # Post-process boxes
            out["preds"] = self.postprocessor(decoded_features)

        if target is not None:
            out["loss"] = self.compute_loss(decoded_features, gt, seq_len)

        return out


class SARPostProcessor(RecognitionPostProcessor):
    """Post processor for SAR architectures

    Args:
        vocab: string containing the ordered sequence of supported characters
    """

    def __call__(
        self,
        logits: tf.Tensor,
    ) -> list[tuple[str, float]]:
        # compute pred with argmax for attention models
        out_idxs = tf.math.argmax(logits, axis=2)
        # N x L
        probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
        # Take the minimum confidence of the sequence
        probs = tf.math.reduce_min(probs, axis=1)

        # decode raw output of the model with tf_label_to_idx
        out_idxs = tf.cast(out_idxs, dtype="int32")
        embedding = tf.constant(self._embedding, dtype=tf.string)
        decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
        decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
        decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
        word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

        return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))


def _sar(
    arch: str,
    pretrained: bool,
    backbone_fn,
    pretrained_backbone: bool = True,
    input_shape: tuple[int, int, int] | None = None,
    **kwargs: Any,
) -> SAR:
    pretrained_backbone = pretrained_backbone and not pretrained

    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["input_shape"] = input_shape or _cfg["input_shape"]
    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])

    # Feature extractor
    feat_extractor = backbone_fn(
        pretrained=pretrained_backbone,
        input_shape=_cfg["input_shape"],
        include_top=False,
    )

    kwargs["vocab"] = _cfg["vocab"]

    # Build the model
    model = SAR(feat_extractor, cfg=_cfg, **kwargs)
    _build_model(model)
    # Load pretrained parameters
    if pretrained:
        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
        load_pretrained_params(
            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
        )

    return model


def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
    """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read: A Simple and Strong
    Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import sar_resnet31
    >>> model = sar_resnet31(pretrained=False)
    >>> input_tensor = tf.random.uniform(shape=[1, 64, 256, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the SAR architecture

    Returns:
        text recognition architecture
    """
    return _sar("sar_resnet31", pretrained, resnet31, **kwargs)
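The 416 deleted lines above are the TensorFlow implementation of SAR, matching `doctr/models/recognition/sar/tensorflow.py +0 -416` in the summary; its PyTorch counterpart (`doctr/models/recognition/sar/pytorch.py`) survives the release. A rough PyTorch transcription of the removed docstring example would look like the sketch below, assuming the PyTorch `sar_resnet31` keeps its constructor and channels-first input convention:

import torch
from doctr.models import sar_resnet31

# channels-first equivalent of the (H, W, 3) TF input used in the removed example
model = sar_resnet31(pretrained=False)
input_tensor = torch.rand((1, 3, 32, 128), dtype=torch.float32)
out = model(input_tensor)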
doctr/models/recognition/vitstr/tensorflow.py (removed)
@@ -1,278 +0,0 @@
# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any

import tensorflow as tf
from tensorflow.keras import Model, layers

from doctr.datasets import VOCABS

from ...classification import vit_b, vit_s
from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
from .base import _ViTSTR, _ViTSTRPostProcessor

__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

default_cfgs: dict[str, dict[str, Any]] = {
    "vitstr_small": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (32, 128, 3),
        "vocab": VOCABS["french"],
        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
    },
    "vitstr_base": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (32, 128, 3),
        "vocab": VOCABS["french"],
        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0",
    },
}


class ViTSTR(_ViTSTR, Model):
    """Implements a ViTSTR architecture as described in `"Vision Transformer for Fast and
    Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.

    Args:
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        embedding_units: number of embedding units
        max_length: maximum word length handled by the model
        dropout_prob: dropout probability for the encoder and decoder
        input_shape: input shape of the image
        exportable: onnx exportable returns only logits
        cfg: dictionary containing information about the model
    """

    _children_names: list[str] = ["feat_extractor", "postprocessor"]

    def __init__(
        self,
        feature_extractor,
        vocab: str,
        embedding_units: int,
        max_length: int = 32,
        dropout_prob: float = 0.0,
        input_shape: tuple[int, int, int] = (32, 128, 3),  # different from paper
        exportable: bool = False,
        cfg: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.vocab = vocab
        self.exportable = exportable
        self.cfg = cfg
        self.max_length = max_length + 2  # +2 for SOS and EOS

        self.feat_extractor = feature_extractor
        self.head = layers.Dense(len(self.vocab) + 1, name="head")  # +1 for EOS

        self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)

    @staticmethod
    def compute_loss(
        model_output: tf.Tensor,
        gt: tf.Tensor,
        seq_len: list[int],
    ) -> tf.Tensor:
        """Compute categorical cross-entropy loss for the model.
        Sequences are masked after the EOS character.

        Args:
            model_output: predicted logits of the model
            gt: the encoded tensor with gt labels
            seq_len: lengths of each gt word inside the batch

        Returns:
            The loss of the model on the batch
        """
        # Input length : number of steps
        input_len = tf.shape(model_output)[1]
        # Add one for additional <eos> token (sos disappear in shift!)
        seq_len = tf.cast(seq_len, tf.int32) + 1
        # One-hot gt labels
        oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
        # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
        # The "masked" first gt char is <sos>.
        cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output)
        # Compute mask
        mask_values = tf.zeros_like(cce)
        mask_2d = tf.sequence_mask(seq_len, input_len)
        masked_loss = tf.where(mask_2d, cce, mask_values)
        ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))

        return tf.expand_dims(ce_loss, axis=1)

    def call(
        self,
        x: tf.Tensor,
        target: list[str] | None = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)

        if target is not None:
            gt, seq_len = self.build_target(target)
            seq_len = tf.cast(seq_len, tf.int32)

        if kwargs.get("training", False) and target is None:
            raise ValueError("Need to provide labels during training")

        features = features[:, : self.max_length]  # (batch_size, max_length, d_model)
        B, N, E = features.shape
        features = tf.reshape(features, (B * N, E))
        logits = tf.reshape(
            self.head(features, **kwargs), (B, N, len(self.vocab) + 1)
        )  # (batch_size, max_length, vocab + 1)
        decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token

        out: dict[str, tf.Tensor] = {}
        if self.exportable:
            out["logits"] = decoded_features
            return out

        if return_model_output:
            out["out_map"] = decoded_features

        if target is None or return_preds:
            # Post-process boxes
            out["preds"] = self.postprocessor(decoded_features)

        if target is not None:
            out["loss"] = self.compute_loss(decoded_features, gt, seq_len)

        return out


class ViTSTRPostProcessor(_ViTSTRPostProcessor):
    """Post processor for ViTSTR architecture

    Args:
        vocab: string containing the ordered sequence of supported characters
    """

    def __call__(
        self,
        logits: tf.Tensor,
    ) -> list[tuple[str, float]]:
        # compute pred with argmax for attention models
        out_idxs = tf.math.argmax(logits, axis=2)
        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)

        # decode raw output of the model with tf_label_to_idx
        out_idxs = tf.cast(out_idxs, dtype="int32")
        embedding = tf.constant(self._embedding, dtype=tf.string)
        decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
        decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
        decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
        word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

        # compute probabilities for each word up to the EOS token
        probs = [
            preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
            for i, word in enumerate(word_values)
        ]

        return list(zip(word_values, probs))


def _vitstr(
    arch: str,
    pretrained: bool,
    backbone_fn,
    input_shape: tuple[int, int, int] | None = None,
    **kwargs: Any,
) -> ViTSTR:
    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["input_shape"] = input_shape or _cfg["input_shape"]
    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
    patch_size = kwargs.get("patch_size", (4, 8))

    kwargs["vocab"] = _cfg["vocab"]

    # Feature extractor
    feat_extractor = backbone_fn(
        # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch
        pretrained=False,
        input_shape=_cfg["input_shape"],
        patch_size=patch_size,
        include_top=False,
    )

    kwargs.pop("patch_size", None)
    kwargs.pop("pretrained_backbone", None)

    # Build the model
    model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
    _build_model(model)

    # Load pretrained parameters
    if pretrained:
        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
        load_pretrained_params(
            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
        )

    return model


def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
    """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
    <https://arxiv.org/pdf/2105.08582.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import vitstr_small
    >>> model = vitstr_small(pretrained=False)
    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the ViTSTR architecture

    Returns:
        text recognition architecture
    """
    return _vitstr(
        "vitstr_small",
        pretrained,
        vit_s,
        embedding_units=384,
        patch_size=(4, 8),
        **kwargs,
    )


def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
    """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
    <https://arxiv.org/pdf/2105.08582.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import vitstr_base
    >>> model = vitstr_base(pretrained=False)
    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the ViTSTR architecture

    Returns:
        text recognition architecture
    """
    return _vitstr(
        "vitstr_base",
        pretrained,
        vit_b,
        embedding_units=768,
        patch_size=(4, 8),
        **kwargs,
    )