python-doctr 0.12.0 → 1.0.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/contrib/artefacts.py +1 -1
- doctr/contrib/base.py +1 -1
- doctr/datasets/__init__.py +0 -5
- doctr/datasets/coco_text.py +1 -1
- doctr/datasets/cord.py +1 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/base.py +1 -1
- doctr/datasets/datasets/pytorch.py +3 -3
- doctr/datasets/detection.py +1 -1
- doctr/datasets/doc_artefacts.py +1 -1
- doctr/datasets/funsd.py +1 -1
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/generator/base.py +1 -1
- doctr/datasets/generator/pytorch.py +1 -1
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +1 -1
- doctr/datasets/iiit5k.py +1 -1
- doctr/datasets/iiithws.py +1 -1
- doctr/datasets/imgur5k.py +1 -1
- doctr/datasets/mjsynth.py +1 -1
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/orientation.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/sroie.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +1 -1
- doctr/datasets/synthtext.py +1 -1
- doctr/datasets/utils.py +1 -1
- doctr/datasets/vocabs.py +1 -3
- doctr/datasets/wildreceipt.py +1 -1
- doctr/file_utils.py +3 -102
- doctr/io/elements.py +1 -1
- doctr/io/html.py +1 -1
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/base.py +1 -1
- doctr/io/image/pytorch.py +2 -2
- doctr/io/pdf.py +1 -1
- doctr/io/reader.py +1 -1
- doctr/models/_utils.py +56 -18
- doctr/models/builder.py +1 -1
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -3
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +1 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +1 -1
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +2 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +1 -1
- doctr/models/classification/vip/__init__.py +1 -4
- doctr/models/classification/vip/layers/__init__.py +1 -4
- doctr/models/classification/vip/layers/pytorch.py +2 -2
- doctr/models/classification/vip/pytorch.py +1 -1
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +3 -3
- doctr/models/classification/zoo.py +7 -12
- doctr/models/core.py +1 -1
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/_utils/base.py +1 -1
- doctr/models/detection/_utils/pytorch.py +1 -1
- doctr/models/detection/core.py +2 -2
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +5 -13
- doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +5 -15
- doctr/models/detection/fast/pytorch.py +5 -5
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +4 -13
- doctr/models/detection/linknet/pytorch.py +3 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +2 -2
- doctr/models/detection/zoo.py +16 -33
- doctr/models/factory/hub.py +26 -34
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/base.py +1 -1
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +4 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +3 -3
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +4 -9
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +28 -33
- doctr/models/recognition/core.py +1 -1
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +7 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/base.py +1 -1
- doctr/models/recognition/master/pytorch.py +6 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/base.py +1 -1
- doctr/models/recognition/parseq/pytorch.py +6 -6
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +8 -17
- doctr/models/recognition/predictor/pytorch.py +2 -3
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +4 -4
- doctr/models/recognition/utils.py +1 -1
- doctr/models/recognition/viptr/__init__.py +1 -4
- doctr/models/recognition/viptr/pytorch.py +4 -4
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/base.py +1 -1
- doctr/models/recognition/vitstr/pytorch.py +4 -4
- doctr/models/recognition/zoo.py +14 -14
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +3 -2
- doctr/models/zoo.py +1 -1
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/base.py +3 -2
- doctr/transforms/functional/pytorch.py +5 -5
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +28 -94
- doctr/transforms/modules/pytorch.py +29 -27
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +1 -2
- doctr/utils/fonts.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/metrics.py +1 -1
- doctr/utils/multithreading.py +1 -1
- doctr/utils/reconstitution.py +1 -1
- doctr/utils/repr.py +1 -1
- doctr/utils/visualization.py +2 -2
- doctr/version.py +1 -1
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
- python_doctr-1.0.1.dist-info/RECORD +149 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -442
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -418
- doctr/models/classification/textnet/tensorflow.py +0 -275
- doctr/models/classification/vgg/tensorflow.py +0 -125
- doctr/models/classification/vit/tensorflow.py +0 -201
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
- doctr/models/detection/fast/tensorflow.py +0 -427
- doctr/models/detection/linknet/tensorflow.py +0 -377
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -317
- doctr/models/recognition/master/tensorflow.py +0 -320
- doctr/models/recognition/parseq/tensorflow.py +0 -516
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -423
- doctr/models/recognition/vitstr/tensorflow.py +0 -285
- doctr/models/utils/tensorflow.py +0 -189
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.12.0.dist-info/RECORD +0 -180
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
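
One pattern is already visible in the file list: every backend-specific tensorflow.py module is deleted outright (+0 −N), along with the TensorFlow-only doctr/datasets/loader.py, while the pytorch.py counterparts receive only minor edits. In other words, the 0.12.0 → 1.0.1 upgrade removes the TensorFlow backend and leaves PyTorch as the sole backend. A hedged sketch of the usage that survives the upgrade, assuming doctr's PyTorch API (which this diff does not show) mirrors the removed TensorFlow doctest reproduced further down:

    # Sketch only: the assumed PyTorch counterpart of the deleted TF doctest below.
    # PyTorch uses channels-first (N, C, H, W) inputs, versus the NHWC shape
    # [1, 32, 128, 3] fed to the TensorFlow model.
    import torch
    from doctr.models import parseq

    model = parseq(pretrained=False)
    input_tensor = torch.rand((1, 3, 32, 128), dtype=torch.float32)
    out = model(input_tensor)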
doctr/models/recognition/parseq/tensorflow.py
@@ -1,516 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-import math
-from copy import deepcopy
-from itertools import permutations
-from typing import Any
-
-import numpy as np
-import tensorflow as tf
-from tensorflow.keras import Model, layers
-
-from doctr.datasets import VOCABS
-from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
-
-from ...classification import vit_s
-from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
-from .base import _PARSeq, _PARSeqPostProcessor
-
-__all__ = ["PARSeq", "parseq"]
-
-default_cfgs: dict[str, dict[str, Any]] = {
-    "parseq": {
-        "mean": (0.694, 0.695, 0.693),
-        "std": (0.299, 0.296, 0.301),
-        "input_shape": (32, 128, 3),
-        "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
-    },
-}
-
-
-class CharEmbedding(layers.Layer):
-    """Implements the character embedding module
-
-    Args:
-
-        vocab_size: size of the vocabulary
-        d_model: dimension of the model
-    """
-
-    def __init__(self, vocab_size: int, d_model: int):
-        super(CharEmbedding, self).__init__()
-        self.embedding = layers.Embedding(vocab_size, d_model)
-        self.d_model = d_model
-
-    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
-        return math.sqrt(self.d_model) * self.embedding(x, **kwargs)
-
-
-class PARSeqDecoder(layers.Layer):
-    """Implements decoder module of the PARSeq model
-
-    Args:
-        d_model: dimension of the model
-        num_heads: number of attention heads
-        ffd: dimension of the feed forward layer
-        ffd_ratio: depth multiplier for the feed forward layer
-        dropout: dropout rate
-    """
-
-    def __init__(
-        self,
-        d_model: int,
-        num_heads: int = 12,
-        ffd: int = 2048,
-        ffd_ratio: int = 4,
-        dropout: float = 0.1,
-    ):
-        super(PARSeqDecoder, self).__init__()
-        self.attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
-        self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
-        self.position_feed_forward = PositionwiseFeedForward(
-            d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
-        )
-
-        self.query_norm = layers.LayerNormalization(epsilon=1e-5)
-        self.content_norm = layers.LayerNormalization(epsilon=1e-5)
-        self.feed_forward_norm = layers.LayerNormalization(epsilon=1e-5)
-        self.output_norm = layers.LayerNormalization(epsilon=1e-5)
-        self.attention_dropout = layers.Dropout(dropout)
-        self.cross_attention_dropout = layers.Dropout(dropout)
-        self.feed_forward_dropout = layers.Dropout(dropout)
-
-    def call(
-        self,
-        target,
-        content,
-        memory,
-        target_mask=None,
-        **kwargs: Any,
-    ):
-        query_norm = self.query_norm(target, **kwargs)
-        content_norm = self.content_norm(content, **kwargs)
-        target = target + self.attention_dropout(
-            self.attention(query_norm, content_norm, content_norm, mask=target_mask, **kwargs), **kwargs
-        )
-        target = target + self.cross_attention_dropout(
-            self.cross_attention(self.query_norm(target, **kwargs), memory, memory, **kwargs), **kwargs
-        )
-        target = target + self.feed_forward_dropout(
-            self.position_feed_forward(self.feed_forward_norm(target, **kwargs), **kwargs), **kwargs
-        )
-        return self.output_norm(target, **kwargs)
-
-
-class PARSeq(_PARSeq, Model):
-    """Implements a PARSeq architecture as described in `"Scene Text Recognition
-    with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
-    Modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
-
-    Args:
-        feature_extractor: the backbone serving as feature extractor
-        vocab: vocabulary used for encoding
-        embedding_units: number of embedding units
-        max_length: maximum word length handled by the model
-        dropout_prob: dropout probability for the decoder
-        dec_num_heads: number of attention heads in the decoder
-        dec_ff_dim: dimension of the feed forward layer in the decoder
-        dec_ffd_ratio: depth multiplier for the feed forward layer in the decoder
-        input_shape: input shape of the image
-        exportable: onnx exportable returns only logits
-        cfg: dictionary containing information about the model
-    """
-
-    _children_names: list[str] = ["feat_extractor", "postprocessor"]
-
-    def __init__(
-        self,
-        feature_extractor,
-        vocab: str,
-        embedding_units: int,
-        max_length: int = 32,  # different from paper
-        dropout_prob: float = 0.1,
-        dec_num_heads: int = 12,
-        dec_ff_dim: int = 384,  # we use it from the original implementation instead of 2048
-        dec_ffd_ratio: int = 4,
-        input_shape: tuple[int, int, int] = (32, 128, 3),
-        exportable: bool = False,
-        cfg: dict[str, Any] | None = None,
-    ) -> None:
-        super().__init__()
-        self.vocab = vocab
-        self.exportable = exportable
-        self.cfg = cfg
-        self.max_length = max_length
-        self.vocab_size = len(vocab)
-        self.rng = np.random.default_rng()
-
-        self.feat_extractor = feature_extractor
-        self.decoder = PARSeqDecoder(embedding_units, dec_num_heads, dec_ff_dim, dec_ffd_ratio, dropout_prob)
-        self.embed = CharEmbedding(self.vocab_size + 3, embedding_units)  # +3 for SOS, EOS, PAD
-        self.head = layers.Dense(self.vocab_size + 1, name="head")  # +1 for EOS
-        self.pos_queries = self.add_weight(
-            shape=(1, self.max_length + 1, embedding_units),
-            initializer="zeros",
-            trainable=True,
-            name="positions",
-        )
-        self.dropout = layers.Dropout(dropout_prob)
-
-        self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
-
-    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
-        """Load pretrained parameters onto the model
-
-        Args:
-            path_or_url: the path or URL to the model parameters (checkpoint)
-            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
-        """
-        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
-        # ref.: https://github.com/mindee/doctr/issues/1911
-        kwargs["skip_mismatch"] = True
-        load_pretrained_params(self, path_or_url, **kwargs)
-
-    def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
-        # Generates permutations of the target sequence.
-        # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
-        # with small modifications
-
-        max_num_chars = int(tf.reduce_max(seqlen))  # get longest sequence length in batch
-        perms = [tf.range(max_num_chars, dtype=tf.int32)]
-
-        max_perms = math.factorial(max_num_chars) // 2
-        num_gen_perms = min(3, max_perms)
-        if max_num_chars < 5:
-            # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
-            # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
-            if max_num_chars == 4:
-                selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
-            else:
-                selector = list(range(max_perms))
-            perm_pool_candidates = list(permutations(range(max_num_chars), max_num_chars))
-            perm_pool = tf.convert_to_tensor([perm_pool_candidates[i] for i in selector])
-            # If the forward permutation is always selected, no need to add it to the pool for sampling
-            perm_pool = perm_pool[1:]
-            final_perms = tf.stack(perms)
-            if len(perm_pool):
-                i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False)
-                final_perms = tf.concat([final_perms, perm_pool[i[0] : i[1]]], axis=0)
-        else:
-            perms.extend([
-                tf.random.shuffle(tf.range(max_num_chars, dtype=tf.int32)) for _ in range(num_gen_perms - len(perms))
-            ])
-            final_perms = tf.stack(perms)
-
-        comp = tf.reverse(final_perms, axis=[-1])
-        final_perms = tf.stack([final_perms, comp])
-        final_perms = tf.transpose(final_perms, perm=[1, 0, 2])
-        final_perms = tf.reshape(final_perms, shape=(-1, max_num_chars))
-
-        sos_idx = tf.zeros([tf.shape(final_perms)[0], 1], dtype=tf.int32)
-        eos_idx = tf.fill([tf.shape(final_perms)[0], 1], max_num_chars + 1)
-        combined = tf.concat([sos_idx, final_perms + 1, eos_idx], axis=1)
-        combined = tf.cast(combined, dtype=tf.int32)
-        if tf.shape(combined)[0] > 1:
-            combined = tf.tensor_scatter_nd_update(
-                combined, [[1, i] for i in range(1, max_num_chars + 2)], max_num_chars + 1 - tf.range(max_num_chars + 1)
-            )
-        return combined
-
-    def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
-        # Generate source and target mask for the decoder attention.
-        sz = permutation.shape[0]
-        mask = tf.ones((sz, sz), dtype=tf.float32)
-
-        for i in range(sz - 1):
-            query_idx = int(permutation[i])
-            masked_keys = permutation[i + 1 :].numpy().tolist()
-            indices = tf.constant([[query_idx, j] for j in masked_keys], dtype=tf.int32)
-            mask = tf.tensor_scatter_nd_update(mask, indices, tf.zeros(len(masked_keys), dtype=tf.float32))
-
-        source_mask = tf.identity(mask[:-1, :-1])
-        eye_indices = tf.eye(sz, dtype=tf.bool)
-        mask = tf.tensor_scatter_nd_update(
-            mask, tf.where(eye_indices), tf.zeros_like(tf.boolean_mask(mask, eye_indices))
-        )
-        target_mask = mask[1:, :-1]
-        return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)
-
-    def decode(
-        self,
-        target: tf.Tensor,
-        memory: tf.Tensor,
-        target_mask: tf.Tensor | None = None,
-        target_query: tf.Tensor | None = None,
-        **kwargs: Any,
-    ) -> tf.Tensor:
-        batch_size, sequence_length = target.shape
-        # apply positional information to the target sequence excluding the SOS token
-        null_ctx = self.embed(target[:, :1], **kwargs)
-        content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:], **kwargs)
-        content = self.dropout(tf.concat([null_ctx, content], axis=1), **kwargs)
-        if target_query is None:
-            target_query = tf.tile(self.pos_queries[:, :sequence_length], [batch_size, 1, 1])
-        target_query = self.dropout(target_query, **kwargs)
-        return self.decoder(target_query, content, memory, target_mask, **kwargs)
-
-    def decode_autoregressive(self, features: tf.Tensor, max_len: int | None = None, **kwargs) -> tf.Tensor:
-        """Generate predictions for the given features."""
-        max_length = max_len if max_len is not None else self.max_length
-        max_length = min(max_length, self.max_length) + 1
-        b = tf.shape(features)[0]
-        # Padding symbol + SOS at the beginning
-        ys = tf.fill(dims=(b, max_length), value=self.vocab_size + 2)
-        start_vector = tf.fill(dims=(b, 1), value=self.vocab_size + 1)
-        ys = tf.concat([start_vector, ys], axis=-1)
-        pos_queries = tf.tile(self.pos_queries[:, :max_length], [b, 1, 1])
-        query_mask = tf.cast(tf.linalg.band_part(tf.ones((max_length, max_length)), -1, 0), dtype=tf.bool)
-
-        pos_logits = []
-        for i in range(max_length):
-            # Decode one token at a time without providing information about the future tokens
-            tgt_out = self.decode(
-                ys[:, : i + 1],
-                features,
-                query_mask[i : i + 1, : i + 1],
-                target_query=pos_queries[:, i : i + 1],
-                **kwargs,
-            )
-            pos_prob = self.head(tgt_out)
-            pos_logits.append(pos_prob)
-
-            if i + 1 < max_length:
-                # update ys with the next token
-                i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_length), indexing="ij")
-                indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1)
-                ys = tf.tensor_scatter_nd_update(
-                    ys, indices, tf.cast(tf.argmax(pos_prob[:, -1, :], axis=-1), dtype=tf.int32)
-                )
-
-                # Stop decoding if all sequences have reached the EOS token
-                # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-                if (
-                    not self.exportable
-                    and max_len is None
-                    and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1))
-                ):
-                    break
-
-        logits = tf.concat(pos_logits, axis=1)  # (N, max_length, vocab_size + 1)
-
-        # One refine iteration
-        # Update query mask
-        diag_matrix = tf.eye(max_length)
-        diag_matrix = tf.cast(tf.logical_not(tf.cast(diag_matrix, dtype=tf.bool)), dtype=tf.float32)
-        query_mask = tf.cast(tf.concat([diag_matrix[1:], tf.ones((1, max_length))], axis=0), dtype=tf.bool)
-
-        sos = tf.fill((tf.shape(features)[0], 1), self.vocab_size + 1)
-        ys = tf.concat([sos, tf.cast(tf.argmax(logits[:, :-1], axis=-1), dtype=tf.int32)], axis=1)
-        # Create padding mask for refined target input maskes all behind EOS token as False
-        # (N, 1, 1, max_length)
-        mask = tf.cast(tf.equal(ys, self.vocab_size), tf.float32)
-        first_eos_indices = tf.argmax(mask, axis=1, output_type=tf.int32)
-        mask = tf.sequence_mask(first_eos_indices + 1, maxlen=ys.shape[-1], dtype=tf.float32)
-        target_pad_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.bool)
-
-        mask = tf.math.logical_and(target_pad_mask, query_mask[:, : ys.shape[1]])
-        logits = self.head(self.decode(ys, features, mask, target_query=pos_queries, **kwargs), **kwargs)
-
-        return logits  # (N, max_length, vocab_size + 1)
-
-    def call(
-        self,
-        x: tf.Tensor,
-        target: list[str] | None = None,
-        return_model_output: bool = False,
-        return_preds: bool = False,
-        **kwargs: Any,
-    ) -> dict[str, Any]:
-        features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
-        # remove cls token
-        features = features[:, 1:, :]
-
-        if kwargs.get("training", False) and target is None:
-            raise ValueError("Need to provide labels during training")
-
-        if target is not None:
-            gt, seq_len = self.build_target(target)
-            seq_len = tf.cast(seq_len, tf.int32)
-            gt = gt[:, : int(tf.reduce_max(seq_len)) + 2]  # slice up to the max length of the batch + 2 (SOS + EOS)
-
-            if kwargs.get("training", False):
-                # Generate permutations of the target sequences
-                tgt_perms = self.generate_permutations(seq_len)
-
-                gt_in = gt[:, :-1]  # remove EOS token from longest target sequence
-                gt_out = gt[:, 1:]  # remove SOS token
-
-                # Create padding mask for target input
-                # [True, True, True, ..., False, False, False] -> False is masked
-                padding_mask = tf.math.logical_and(
-                    tf.math.not_equal(gt_in, self.vocab_size + 2), tf.math.not_equal(gt_in, self.vocab_size)
-                )
-                padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]  # (N, 1, 1, seq_len)
-
-                loss = tf.constant(0.0)
-                loss_numel = tf.constant(0.0)
-                n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
-                for i, perm in enumerate(tgt_perms):
-                    _, target_mask = self.generate_permutations_attention_masks(perm)  # (seq_len, seq_len)
-                    # combine both masks to (N, 1, seq_len, seq_len)
-                    mask = tf.logical_and(padding_mask, tf.expand_dims(tf.expand_dims(target_mask, axis=0), axis=0))
-
-                    logits = self.head(self.decode(gt_in, features, mask, **kwargs), **kwargs)
-                    logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
-                    targets_flat = tf.reshape(gt_out, (-1,))
-                    mask = tf.not_equal(targets_flat, self.vocab_size + 2)
-                    loss += n * tf.reduce_mean(
-                        tf.nn.sparse_softmax_cross_entropy_with_logits(
-                            labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
-                        )
-                    )
-                    loss_numel += n
-
-                    # After the second iteration (i.e. done with canonical and reverse orderings),
-                    # remove the [EOS] tokens for the succeeding perms
-                    if i == 1:
-                        gt_out = tf.where(tf.equal(gt_out, self.vocab_size), self.vocab_size + 2, gt_out)
-                        n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
-
-                loss /= loss_numel
-
-            else:
-                gt = gt[:, 1:]  # remove SOS token
-                max_len = gt.shape[1] - 1  # exclude EOS token
-                logits = self.decode_autoregressive(features, max_len, **kwargs)
-                logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
-                targets_flat = tf.reshape(gt, (-1,))
-                mask = tf.not_equal(targets_flat, self.vocab_size + 2)
-                loss = tf.reduce_mean(
-                    tf.nn.sparse_softmax_cross_entropy_with_logits(
-                        labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
-                    )
-                )
-        else:
-            logits = self.decode_autoregressive(features, **kwargs)
-
-        logits = _bf16_to_float32(logits)
-
-        out: dict[str, tf.Tensor] = {}
-        if self.exportable:
-            out["logits"] = logits
-            return out
-
-        if return_model_output:
-            out["out_map"] = logits
-
-        if target is None or return_preds:
-            # Post-process boxes
-            out["preds"] = self.postprocessor(logits)
-
-        if target is not None:
-            out["loss"] = loss
-
-        return out
-
-
-class PARSeqPostProcessor(_PARSeqPostProcessor):
-    """Post processor for PARSeq architecture
-
-    Args:
-        vocab: string containing the ordered sequence of supported characters
-    """
-
-    def __call__(
-        self,
-        logits: tf.Tensor,
-    ) -> list[tuple[str, float]]:
-        # compute pred with argmax for attention models
-        out_idxs = tf.math.argmax(logits, axis=2)
-        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
-
-        # decode raw output of the model with tf_label_to_idx
-        out_idxs = tf.cast(out_idxs, dtype="int32")
-        embedding = tf.constant(self._embedding, dtype=tf.string)
-        decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
-        decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
-        decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
-        word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
-
-        # compute probabilties for each word up to the EOS token
-        probs = [
-            preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
-            for i, word in enumerate(word_values)
-        ]
-
-        return list(zip(word_values, probs))
-
-
-def _parseq(
-    arch: str,
-    pretrained: bool,
-    backbone_fn,
-    input_shape: tuple[int, int, int] | None = None,
-    **kwargs: Any,
-) -> PARSeq:
-    # Patch the config
-    _cfg = deepcopy(default_cfgs[arch])
-    _cfg["input_shape"] = input_shape or _cfg["input_shape"]
-    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
-    patch_size = kwargs.get("patch_size", (4, 8))
-
-    kwargs["vocab"] = _cfg["vocab"]
-
-    # Feature extractor
-    feat_extractor = backbone_fn(
-        # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch
-        pretrained=False,
-        input_shape=_cfg["input_shape"],
-        patch_size=patch_size,
-        include_top=False,
-    )
-
-    kwargs.pop("patch_size", None)
-    kwargs.pop("pretrained_backbone", None)
-
-    # Build the model
-    model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
-    _build_model(model)
-
-    # Load pretrained parameters
-    if pretrained:
-        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
-
-    return model


-def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
-    """PARSeq architecture from
-    `"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
-
-    >>> import tensorflow as tf
-    >>> from doctr.models import parseq
-    >>> model = parseq(pretrained=False)
-    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-        **kwargs: keyword arguments of the PARSeq architecture
-
-    Returns:
-        text recognition architecture
-    """
-    return _parseq(
-        "parseq",
-        pretrained,
-        vit_s,
-        embedding_units=384,
-        patch_size=(4, 8),
-        **kwargs,
-    )
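
The deleted implementation above encodes its special tokens by position in the output space rather than with named constants: character indices occupy 0..vocab_size - 1, EOS is vocab_size, SOS is vocab_size + 1 and PAD is vocab_size + 2, which is why the embedding table holds vocab_size + 3 entries while the head only predicts vocab_size + 1 classes (characters plus EOS). A minimal sketch of that convention, not taken from the diff:

    # Toy illustration of the PARSeq token layout used above
    vocab = "abc"
    eos = len(vocab)       # 3 -> what tf.equal(ys, self.vocab_size) tests for
    sos = len(vocab) + 1   # 4 -> the `start_vector` fill value in decode_autoregressive
    pad = len(vocab) + 2   # 5 -> the `ys` fill value and the index ignored by the loss masks

    # "ab" as a training target: <sos> a b <eos>, padded up to the batch's longest word
    encoded = [sos, 0, 1, eos, pad]

    # Permutation training pairs every sampled position ordering with its mirror image,
    # which is what tf.reverse(final_perms, axis=[-1]) does in generate_permutations
    perm = (0, 2, 1)
    mirrored = perm[::-1]  # (1, 2, 0)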
doctr/models/recognition/predictor/tensorflow.py
@@ -1,79 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-from typing import Any
-
-import numpy as np
-import tensorflow as tf
-
-from doctr.models.preprocessor import PreProcessor
-from doctr.utils.repr import NestedObject
-
-from ..core import RecognitionModel
-from ._utils import remap_preds, split_crops
-
-__all__ = ["RecognitionPredictor"]
-
-
-class RecognitionPredictor(NestedObject):
-    """Implements an object able to identify character sequences in images
-
-    Args:
-        pre_processor: transform inputs for easier batched model inference
-        model: core detection architecture
-        split_wide_crops: wether to use crop splitting for high aspect ratio crops
-    """
-
-    _children_names: list[str] = ["pre_processor", "model"]
-
-    def __init__(
-        self,
-        pre_processor: PreProcessor,
-        model: RecognitionModel,
-        split_wide_crops: bool = True,
-    ) -> None:
-        super().__init__()
-        self.pre_processor = pre_processor
-        self.model = model
-        self.split_wide_crops = split_wide_crops
-        self.critical_ar = 8  # Critical aspect ratio
-        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
-        self.target_ar = 6  # Target aspect ratio
-
-    def __call__(
-        self,
-        crops: list[np.ndarray | tf.Tensor],
-        **kwargs: Any,
-    ) -> list[tuple[str, float]]:
-        if len(crops) == 0:
-            return []
-        # Dimension check
-        if any(crop.ndim != 3 for crop in crops):
-            raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
-
-        # Split crops that are too wide
-        remapped = False
-        if self.split_wide_crops:
-            new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.overlap_ratio)
-            if remapped:
-                crops = new_crops
-
-        # Resize & batch them
-        processed_batches = self.pre_processor(crops)
-
-        # Forward it
-        raw = [
-            self.model(batch, return_preds=True, training=False, **kwargs)["preds"]  # type: ignore[operator]
-            for batch in processed_batches
-        ]
-
-        # Process outputs
-        out = [charseq for batch in raw for charseq in batch]
-
-        # Remap crops
-        if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.overlap_ratio)
-
-        return out
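
The predictor above delegates the actual wide-crop handling to split_crops and remap_preds, driven by the three constants set in __init__: crops wider than critical_ar (aspect ratio 8) are cut into overlapping windows of roughly target_ar (6), with neighboring windows sharing overlap_ratio (50%) of their width. An illustrative, self-contained approximation of the splitting step (an assumption for illustration, not the actual doctr split_crops implementation):

    import numpy as np

    def split_wide_crop(
        crop: np.ndarray, critical_ar: float = 8, target_ar: float = 6, overlap_ratio: float = 0.5
    ) -> list[np.ndarray]:
        # Cut a (H, W, C) crop into horizontally overlapping windows if it is too wide
        h, w = crop.shape[:2]
        if w / h <= critical_ar:
            return [crop]  # narrow enough to recognize in one pass
        window = int(target_ar * h)  # window width matching the target aspect ratio
        stride = max(1, int(window * (1 - overlap_ratio)))  # neighbors overlap by `overlap_ratio`
        starts = list(range(0, w - window + 1, stride))
        if starts[-1] + window < w:  # make sure the right edge is covered
            starts.append(w - window)
        return [crop[:, s : s + window] for s in starts]

    # e.g. a 32x512 crop (aspect ratio 16) yields overlapping 32x192 windows
    windows = split_wide_crop(np.zeros((32, 512, 3), dtype=np.uint8))

Recognition then runs on each window, and merging the overlapping character sequences back into a single prediction per original crop is the job remap_preds performs with the same overlap_ratio.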