python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
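Note on the overall change set: every backend-specific */tensorflow.py module is deleted while the PyTorch counterparts remain, so 1.0.0 is effectively a PyTorch-only release (the wheel also adds the VIP classification backbone, the VIPTR recognizer and a COCO-Text dataset). As a rough illustration only, not taken from the diff itself, the usage example from the removed TensorFlow parseq docstring translates to the remaining PyTorch backend along these lines; check the 1.0.0 API reference for the exact signature:

import torch

from doctr.models import parseq

model = parseq(pretrained=False)  # PyTorch module in 1.0.0
# recognition crops are channels-first in the PyTorch backend: (N, C, H, W)
input_tensor = torch.rand((1, 3, 32, 128), dtype=torch.float32)
out = model(input_tensor)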
doctr/models/recognition/parseq/tensorflow.py +0 -508 (file removed)
@@ -1,508 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
- import math
- from copy import deepcopy
- from itertools import permutations
- from typing import Any
-
- import numpy as np
- import tensorflow as tf
- from tensorflow.keras import Model, layers
-
- from doctr.datasets import VOCABS
- from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
-
- from ...classification import vit_s
- from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
- from .base import _PARSeq, _PARSeqPostProcessor
-
- __all__ = ["PARSeq", "parseq"]
-
- default_cfgs: dict[str, dict[str, Any]] = {
-     "parseq": {
-         "mean": (0.694, 0.695, 0.693),
-         "std": (0.299, 0.296, 0.301),
-         "input_shape": (32, 128, 3),
-         "vocab": VOCABS["french"],
-         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
-     },
- }
-
-
- class CharEmbedding(layers.Layer):
-     """Implements the character embedding module
-
-     Args:
-
-         vocab_size: size of the vocabulary
-         d_model: dimension of the model
-     """
-
-     def __init__(self, vocab_size: int, d_model: int):
-         super(CharEmbedding, self).__init__()
-         self.embedding = layers.Embedding(vocab_size, d_model)
-         self.d_model = d_model
-
-     def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
-         return math.sqrt(self.d_model) * self.embedding(x, **kwargs)
-
-
- class PARSeqDecoder(layers.Layer):
-     """Implements decoder module of the PARSeq model
-
-     Args:
-         d_model: dimension of the model
-         num_heads: number of attention heads
-         ffd: dimension of the feed forward layer
-         ffd_ratio: depth multiplier for the feed forward layer
-         dropout: dropout rate
-     """
-
-     def __init__(
-         self,
-         d_model: int,
-         num_heads: int = 12,
-         ffd: int = 2048,
-         ffd_ratio: int = 4,
-         dropout: float = 0.1,
-     ):
-         super(PARSeqDecoder, self).__init__()
-         self.attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
-         self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
-         self.position_feed_forward = PositionwiseFeedForward(
-             d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
-         )
-
-         self.attention_norm = layers.LayerNormalization(epsilon=1e-5)
-         self.cross_attention_norm = layers.LayerNormalization(epsilon=1e-5)
-         self.query_norm = layers.LayerNormalization(epsilon=1e-5)
-         self.content_norm = layers.LayerNormalization(epsilon=1e-5)
-         self.feed_forward_norm = layers.LayerNormalization(epsilon=1e-5)
-         self.output_norm = layers.LayerNormalization(epsilon=1e-5)
-         self.attention_dropout = layers.Dropout(dropout)
-         self.cross_attention_dropout = layers.Dropout(dropout)
-         self.feed_forward_dropout = layers.Dropout(dropout)
-
-     def call(
-         self,
-         target,
-         content,
-         memory,
-         target_mask=None,
-         **kwargs: Any,
-     ):
-         query_norm = self.query_norm(target, **kwargs)
-         content_norm = self.content_norm(content, **kwargs)
-         target = target + self.attention_dropout(
-             self.attention(query_norm, content_norm, content_norm, mask=target_mask, **kwargs), **kwargs
-         )
-         target = target + self.cross_attention_dropout(
-             self.cross_attention(self.query_norm(target, **kwargs), memory, memory, **kwargs), **kwargs
-         )
-         target = target + self.feed_forward_dropout(
-             self.position_feed_forward(self.feed_forward_norm(target, **kwargs), **kwargs), **kwargs
-         )
-         return self.output_norm(target, **kwargs)
-
-
- class PARSeq(_PARSeq, Model):
-     """Implements a PARSeq architecture as described in `"Scene Text Recognition
-     with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
-     Modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
-
-     Args:
-         feature_extractor: the backbone serving as feature extractor
-         vocab: vocabulary used for encoding
-         embedding_units: number of embedding units
-         max_length: maximum word length handled by the model
-         dropout_prob: dropout probability for the decoder
-         dec_num_heads: number of attention heads in the decoder
-         dec_ff_dim: dimension of the feed forward layer in the decoder
-         dec_ffd_ratio: depth multiplier for the feed forward layer in the decoder
-         input_shape: input shape of the image
-         exportable: onnx exportable returns only logits
-         cfg: dictionary containing information about the model
-     """
-
-     _children_names: list[str] = ["feat_extractor", "postprocessor"]
-
-     def __init__(
-         self,
-         feature_extractor,
-         vocab: str,
-         embedding_units: int,
-         max_length: int = 32,  # different from paper
-         dropout_prob: float = 0.1,
-         dec_num_heads: int = 12,
-         dec_ff_dim: int = 384,  # we use it from the original implementation instead of 2048
-         dec_ffd_ratio: int = 4,
-         input_shape: tuple[int, int, int] = (32, 128, 3),
-         exportable: bool = False,
-         cfg: dict[str, Any] | None = None,
-     ) -> None:
-         super().__init__()
-         self.vocab = vocab
-         self.exportable = exportable
-         self.cfg = cfg
-         self.max_length = max_length
-         self.vocab_size = len(vocab)
-         self.rng = np.random.default_rng()
-
-         self.feat_extractor = feature_extractor
-         self.decoder = PARSeqDecoder(embedding_units, dec_num_heads, dec_ff_dim, dec_ffd_ratio, dropout_prob)
-         self.embed = CharEmbedding(self.vocab_size + 3, embedding_units)  # +3 for SOS, EOS, PAD
-         self.head = layers.Dense(self.vocab_size + 1, name="head")  # +1 for EOS
-         self.pos_queries = self.add_weight(
-             shape=(1, self.max_length + 1, embedding_units),
-             initializer="zeros",
-             trainable=True,
-             name="positions",
-         )
-         self.dropout = layers.Dropout(dropout_prob)
-
-         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
-
-     def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
-         # Generates permutations of the target sequence.
-         # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
-         # with small modifications
-
-         max_num_chars = int(tf.reduce_max(seqlen))  # get longest sequence length in batch
-         perms = [tf.range(max_num_chars, dtype=tf.int32)]
-
-         max_perms = math.factorial(max_num_chars) // 2
-         num_gen_perms = min(3, max_perms)
-         if max_num_chars < 5:
-             # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
-             # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
-             if max_num_chars == 4:
-                 selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
-             else:
-                 selector = list(range(max_perms))
-             perm_pool_candidates = list(permutations(range(max_num_chars), max_num_chars))
-             perm_pool = tf.convert_to_tensor([perm_pool_candidates[i] for i in selector])
-             # If the forward permutation is always selected, no need to add it to the pool for sampling
-             perm_pool = perm_pool[1:]
-             final_perms = tf.stack(perms)
-             if len(perm_pool):
-                 i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False)
-                 final_perms = tf.concat([final_perms, perm_pool[i[0] : i[1]]], axis=0)
-         else:
-             perms.extend([
-                 tf.random.shuffle(tf.range(max_num_chars, dtype=tf.int32)) for _ in range(num_gen_perms - len(perms))
-             ])
-             final_perms = tf.stack(perms)
-
-         comp = tf.reverse(final_perms, axis=[-1])
-         final_perms = tf.stack([final_perms, comp])
-         final_perms = tf.transpose(final_perms, perm=[1, 0, 2])
-         final_perms = tf.reshape(final_perms, shape=(-1, max_num_chars))
-
-         sos_idx = tf.zeros([tf.shape(final_perms)[0], 1], dtype=tf.int32)
-         eos_idx = tf.fill([tf.shape(final_perms)[0], 1], max_num_chars + 1)
-         combined = tf.concat([sos_idx, final_perms + 1, eos_idx], axis=1)
-         combined = tf.cast(combined, dtype=tf.int32)
-         if tf.shape(combined)[0] > 1:
-             combined = tf.tensor_scatter_nd_update(
-                 combined, [[1, i] for i in range(1, max_num_chars + 2)], max_num_chars + 1 - tf.range(max_num_chars + 1)
-             )
-         return combined
-
-     def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
-         # Generate source and target mask for the decoder attention.
-         sz = permutation.shape[0]
-         mask = tf.ones((sz, sz), dtype=tf.float32)
-
-         for i in range(sz - 1):
-             query_idx = int(permutation[i])
-             masked_keys = permutation[i + 1 :].numpy().tolist()
-             indices = tf.constant([[query_idx, j] for j in masked_keys], dtype=tf.int32)
-             mask = tf.tensor_scatter_nd_update(mask, indices, tf.zeros(len(masked_keys), dtype=tf.float32))
-
-         source_mask = tf.identity(mask[:-1, :-1])
-         eye_indices = tf.eye(sz, dtype=tf.bool)
-         mask = tf.tensor_scatter_nd_update(
-             mask, tf.where(eye_indices), tf.zeros_like(tf.boolean_mask(mask, eye_indices))
-         )
-         target_mask = mask[1:, :-1]
-         return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)
-
-     def decode(
-         self,
-         target: tf.Tensor,
-         memory: tf.Tensor,
-         target_mask: tf.Tensor | None = None,
-         target_query: tf.Tensor | None = None,
-         **kwargs: Any,
-     ) -> tf.Tensor:
-         batch_size, sequence_length = target.shape
-         # apply positional information to the target sequence excluding the SOS token
-         null_ctx = self.embed(target[:, :1], **kwargs)
-         content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:], **kwargs)
-         content = self.dropout(tf.concat([null_ctx, content], axis=1), **kwargs)
-         if target_query is None:
-             target_query = tf.tile(self.pos_queries[:, :sequence_length], [batch_size, 1, 1])
-         target_query = self.dropout(target_query, **kwargs)
-         return self.decoder(target_query, content, memory, target_mask, **kwargs)
-
-     def decode_autoregressive(self, features: tf.Tensor, max_len: int | None = None, **kwargs) -> tf.Tensor:
-         """Generate predictions for the given features."""
-         max_length = max_len if max_len is not None else self.max_length
-         max_length = min(max_length, self.max_length) + 1
-         b = tf.shape(features)[0]
-         # Padding symbol + SOS at the beginning
-         ys = tf.fill(dims=(b, max_length), value=self.vocab_size + 2)
-         start_vector = tf.fill(dims=(b, 1), value=self.vocab_size + 1)
-         ys = tf.concat([start_vector, ys], axis=-1)
-         pos_queries = tf.tile(self.pos_queries[:, :max_length], [b, 1, 1])
-         query_mask = tf.cast(tf.linalg.band_part(tf.ones((max_length, max_length)), -1, 0), dtype=tf.bool)
-
-         pos_logits = []
-         for i in range(max_length):
-             # Decode one token at a time without providing information about the future tokens
-             tgt_out = self.decode(
-                 ys[:, : i + 1],
-                 features,
-                 query_mask[i : i + 1, : i + 1],
-                 target_query=pos_queries[:, i : i + 1],
-                 **kwargs,
-             )
-             pos_prob = self.head(tgt_out)
-             pos_logits.append(pos_prob)
-
-             if i + 1 < max_length:
-                 # update ys with the next token
-                 i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_length), indexing="ij")
-                 indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1)
-                 ys = tf.tensor_scatter_nd_update(
-                     ys, indices, tf.cast(tf.argmax(pos_prob[:, -1, :], axis=-1), dtype=tf.int32)
-                 )
-
-                 # Stop decoding if all sequences have reached the EOS token
-                 # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-                 if (
-                     not self.exportable
-                     and max_len is None
-                     and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1))
-                 ):
-                     break
-
-         logits = tf.concat(pos_logits, axis=1)  # (N, max_length, vocab_size + 1)
-
-         # One refine iteration
-         # Update query mask
-         diag_matrix = tf.eye(max_length)
-         diag_matrix = tf.cast(tf.logical_not(tf.cast(diag_matrix, dtype=tf.bool)), dtype=tf.float32)
-         query_mask = tf.cast(tf.concat([diag_matrix[1:], tf.ones((1, max_length))], axis=0), dtype=tf.bool)
-
-         sos = tf.fill((tf.shape(features)[0], 1), self.vocab_size + 1)
-         ys = tf.concat([sos, tf.cast(tf.argmax(logits[:, :-1], axis=-1), dtype=tf.int32)], axis=1)
-         # Create padding mask for refined target input maskes all behind EOS token as False
-         # (N, 1, 1, max_length)
-         mask = tf.cast(tf.equal(ys, self.vocab_size), tf.float32)
-         first_eos_indices = tf.argmax(mask, axis=1, output_type=tf.int32)
-         mask = tf.sequence_mask(first_eos_indices + 1, maxlen=ys.shape[-1], dtype=tf.float32)
-         target_pad_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.bool)
-
-         mask = tf.math.logical_and(target_pad_mask, query_mask[:, : ys.shape[1]])
-         logits = self.head(self.decode(ys, features, mask, target_query=pos_queries, **kwargs), **kwargs)
-
-         return logits  # (N, max_length, vocab_size + 1)
-
-     def call(
-         self,
-         x: tf.Tensor,
-         target: list[str] | None = None,
-         return_model_output: bool = False,
-         return_preds: bool = False,
-         **kwargs: Any,
-     ) -> dict[str, Any]:
-         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
-         # remove cls token
-         features = features[:, 1:, :]
-
-         if kwargs.get("training", False) and target is None:
-             raise ValueError("Need to provide labels during training")
-
-         if target is not None:
-             gt, seq_len = self.build_target(target)
-             seq_len = tf.cast(seq_len, tf.int32)
-             gt = gt[:, : int(tf.reduce_max(seq_len)) + 2]  # slice up to the max length of the batch + 2 (SOS + EOS)
-
-             if kwargs.get("training", False):
-                 # Generate permutations of the target sequences
-                 tgt_perms = self.generate_permutations(seq_len)
-
-                 gt_in = gt[:, :-1]  # remove EOS token from longest target sequence
-                 gt_out = gt[:, 1:]  # remove SOS token
-
-                 # Create padding mask for target input
-                 # [True, True, True, ..., False, False, False] -> False is masked
-                 padding_mask = tf.math.logical_and(
-                     tf.math.not_equal(gt_in, self.vocab_size + 2), tf.math.not_equal(gt_in, self.vocab_size)
-                 )
-                 padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]  # (N, 1, 1, seq_len)
-
-                 loss = tf.constant(0.0)
-                 loss_numel = tf.constant(0.0)
-                 n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
-                 for i, perm in enumerate(tgt_perms):
-                     _, target_mask = self.generate_permutations_attention_masks(perm)  # (seq_len, seq_len)
-                     # combine both masks to (N, 1, seq_len, seq_len)
-                     mask = tf.logical_and(padding_mask, tf.expand_dims(tf.expand_dims(target_mask, axis=0), axis=0))
-
-                     logits = self.head(self.decode(gt_in, features, mask, **kwargs), **kwargs)
-                     logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
-                     targets_flat = tf.reshape(gt_out, (-1,))
-                     mask = tf.not_equal(targets_flat, self.vocab_size + 2)
-                     loss += n * tf.reduce_mean(
-                         tf.nn.sparse_softmax_cross_entropy_with_logits(
-                             labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
-                         )
-                     )
-                     loss_numel += n
-
-                     # After the second iteration (i.e. done with canonical and reverse orderings),
-                     # remove the [EOS] tokens for the succeeding perms
-                     if i == 1:
-                         gt_out = tf.where(tf.equal(gt_out, self.vocab_size), self.vocab_size + 2, gt_out)
-                         n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
-
-                 loss /= loss_numel
-
-             else:
-                 gt = gt[:, 1:]  # remove SOS token
-                 max_len = gt.shape[1] - 1  # exclude EOS token
-                 logits = self.decode_autoregressive(features, max_len, **kwargs)
-                 logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
-                 targets_flat = tf.reshape(gt, (-1,))
-                 mask = tf.not_equal(targets_flat, self.vocab_size + 2)
-                 loss = tf.reduce_mean(
-                     tf.nn.sparse_softmax_cross_entropy_with_logits(
-                         labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
-                     )
-                 )
-         else:
-             logits = self.decode_autoregressive(features, **kwargs)
-
-         logits = _bf16_to_float32(logits)
-
-         out: dict[str, tf.Tensor] = {}
-         if self.exportable:
-             out["logits"] = logits
-             return out
-
-         if return_model_output:
-             out["out_map"] = logits
-
-         if target is None or return_preds:
-             # Post-process boxes
-             out["preds"] = self.postprocessor(logits)
-
-         if target is not None:
-             out["loss"] = loss
-
-         return out
-
-
- class PARSeqPostProcessor(_PARSeqPostProcessor):
-     """Post processor for PARSeq architecture
-
-     Args:
-         vocab: string containing the ordered sequence of supported characters
-     """
-
-     def __call__(
-         self,
-         logits: tf.Tensor,
-     ) -> list[tuple[str, float]]:
-         # compute pred with argmax for attention models
-         out_idxs = tf.math.argmax(logits, axis=2)
-         preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
-
-         # decode raw output of the model with tf_label_to_idx
-         out_idxs = tf.cast(out_idxs, dtype="int32")
-         embedding = tf.constant(self._embedding, dtype=tf.string)
-         decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
-         decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
-         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
-         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
-
-         # compute probabilties for each word up to the EOS token
-         probs = [
-             preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
-             for i, word in enumerate(word_values)
-         ]
-
-         return list(zip(word_values, probs))
-
-
- def _parseq(
-     arch: str,
-     pretrained: bool,
-     backbone_fn,
-     input_shape: tuple[int, int, int] | None = None,
-     **kwargs: Any,
- ) -> PARSeq:
-     # Patch the config
-     _cfg = deepcopy(default_cfgs[arch])
-     _cfg["input_shape"] = input_shape or _cfg["input_shape"]
-     _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
-     patch_size = kwargs.get("patch_size", (4, 8))
-
-     kwargs["vocab"] = _cfg["vocab"]
-
-     # Feature extractor
-     feat_extractor = backbone_fn(
-         # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch
-         pretrained=False,
-         input_shape=_cfg["input_shape"],
-         patch_size=patch_size,
-         include_top=False,
-     )
-
-     kwargs.pop("patch_size", None)
-     kwargs.pop("pretrained_backbone", None)
-
-     # Build the model
-     model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
-     _build_model(model)
-
-     # Load pretrained parameters
-     if pretrained:
-         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-         load_pretrained_params(
-             model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-         )
-
-     return model
-
-
- def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
-     """PARSeq architecture from
-     `"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
-
-     >>> import tensorflow as tf
-     >>> from doctr.models import parseq
-     >>> model = parseq(pretrained=False)
-     >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
-     >>> out = model(input_tensor)
-
-     Args:
-         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-         **kwargs: keyword arguments of the PARSeq architecture
-
-     Returns:
-         text recognition architecture
-     """
-     return _parseq(
-         "parseq",
-         pretrained,
-         vit_s,
-         embedding_units=384,
-         patch_size=(4, 8),
-         **kwargs,
-     )
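The removed module above implements PARSeq's permuted autoregressive decoding: generate_permutations samples several decoding orders per batch, and generate_permutations_attention_masks turns each order into an attention mask; both ideas carry over to the PyTorch implementation that stays in the package. A simplified standalone sketch of the masking idea (NumPy, illustrative only; the real code also handles SOS/EOS positions, padding and the diagonal):

import numpy as np

def permutation_mask(perm: np.ndarray) -> np.ndarray:
    # mask[q, k] == 1 means position q may attend to position k:
    # under the ordering perm, a position only sees positions decoded before it.
    sz = len(perm)
    mask = np.zeros((sz, sz), dtype=np.float32)
    for i in range(sz):
        mask[perm[i], perm[:i]] = 1.0
    return mask

# the left-to-right order gives the usual causal pattern,
# a shuffled order routes information differently over the same tokens
print(permutation_mask(np.array([0, 1, 2, 3])))
print(permutation_mask(np.array([2, 0, 3, 1])))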
doctr/models/recognition/predictor/tensorflow.py +0 -79 (file removed)
@@ -1,79 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
- from typing import Any
-
- import numpy as np
- import tensorflow as tf
-
- from doctr.models.preprocessor import PreProcessor
- from doctr.utils.repr import NestedObject
-
- from ..core import RecognitionModel
- from ._utils import remap_preds, split_crops
-
- __all__ = ["RecognitionPredictor"]
-
-
- class RecognitionPredictor(NestedObject):
-     """Implements an object able to identify character sequences in images
-
-     Args:
-         pre_processor: transform inputs for easier batched model inference
-         model: core detection architecture
-         split_wide_crops: wether to use crop splitting for high aspect ratio crops
-     """
-
-     _children_names: list[str] = ["pre_processor", "model"]
-
-     def __init__(
-         self,
-         pre_processor: PreProcessor,
-         model: RecognitionModel,
-         split_wide_crops: bool = True,
-     ) -> None:
-         super().__init__()
-         self.pre_processor = pre_processor
-         self.model = model
-         self.split_wide_crops = split_wide_crops
-         self.critical_ar = 8  # Critical aspect ratio
-         self.dil_factor = 1.4  # Dilation factor to overlap the crops
-         self.target_ar = 6  # Target aspect ratio
-
-     def __call__(
-         self,
-         crops: list[np.ndarray | tf.Tensor],
-         **kwargs: Any,
-     ) -> list[tuple[str, float]]:
-         if len(crops) == 0:
-             return []
-         # Dimension check
-         if any(crop.ndim != 3 for crop in crops):
-             raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
-
-         # Split crops that are too wide
-         remapped = False
-         if self.split_wide_crops:
-             new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.dil_factor)
-             if remapped:
-                 crops = new_crops
-
-         # Resize & batch them
-         processed_batches = self.pre_processor(crops)
-
-         # Forward it
-         raw = [
-             self.model(batch, return_preds=True, training=False, **kwargs)["preds"]  # type: ignore[operator]
-             for batch in processed_batches
-         ]
-
-         # Process outputs
-         out = [charseq for batch in raw for charseq in batch]
-
-         # Remap crops
-         if self.split_wide_crops and remapped:
-             out = remap_preds(out, crop_map, self.dil_factor)
-
-         return out
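For context on the removed predictor: crops whose aspect ratio exceeds critical_ar are split into overlapping sub-crops close to target_ar (the overlap is controlled by dil_factor), each piece is recognized separately, and remap_preds merges the partial predictions; the split_crops/remap_preds helpers live in the predictor's _utils.py, which this release reworks rather than removes. A rough standalone sketch of the splitting step (illustrative only, not doctr's exact split_crops):

import math

import numpy as np

def split_wide_crop(crop: np.ndarray, critical_ar: float = 8, target_ar: float = 6, dil_factor: float = 1.4):
    h, w = crop.shape[:2]
    aspect_ratio = w / h
    if aspect_ratio <= critical_ar:
        return [crop]  # narrow enough to recognize in one pass
    # number of sub-crops needed to bring each piece near the target aspect ratio
    n_splits = math.ceil(aspect_ratio / target_ar)
    base_width = w / n_splits
    # dilate each slice so that neighbours overlap and no character is cut in half
    slice_width = min(w, int(base_width * dil_factor))
    starts = [int(round(i * base_width)) for i in range(n_splits)]
    return [crop[:, s : min(w, s + slice_width)] for s in starts]

# a 40x480 crop (aspect ratio 12) is split into two overlapping pieces,
# each recognized separately before the predictions are merged back
pieces = split_wide_crop(np.zeros((40, 480, 3), dtype=np.uint8))
print([p.shape for p in pieces])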