python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (170)
  1. doctr/__init__.py +0 -1
  2. doctr/contrib/artefacts.py +1 -1
  3. doctr/contrib/base.py +1 -1
  4. doctr/datasets/__init__.py +0 -5
  5. doctr/datasets/coco_text.py +1 -1
  6. doctr/datasets/cord.py +1 -1
  7. doctr/datasets/datasets/__init__.py +1 -6
  8. doctr/datasets/datasets/base.py +1 -1
  9. doctr/datasets/datasets/pytorch.py +3 -3
  10. doctr/datasets/detection.py +1 -1
  11. doctr/datasets/doc_artefacts.py +1 -1
  12. doctr/datasets/funsd.py +1 -1
  13. doctr/datasets/generator/__init__.py +1 -6
  14. doctr/datasets/generator/base.py +1 -1
  15. doctr/datasets/generator/pytorch.py +1 -1
  16. doctr/datasets/ic03.py +1 -1
  17. doctr/datasets/ic13.py +1 -1
  18. doctr/datasets/iiit5k.py +1 -1
  19. doctr/datasets/iiithws.py +1 -1
  20. doctr/datasets/imgur5k.py +1 -1
  21. doctr/datasets/mjsynth.py +1 -1
  22. doctr/datasets/ocr.py +1 -1
  23. doctr/datasets/orientation.py +1 -1
  24. doctr/datasets/recognition.py +1 -1
  25. doctr/datasets/sroie.py +1 -1
  26. doctr/datasets/svhn.py +1 -1
  27. doctr/datasets/svt.py +1 -1
  28. doctr/datasets/synthtext.py +1 -1
  29. doctr/datasets/utils.py +1 -1
  30. doctr/datasets/vocabs.py +1 -3
  31. doctr/datasets/wildreceipt.py +1 -1
  32. doctr/file_utils.py +3 -102
  33. doctr/io/elements.py +1 -1
  34. doctr/io/html.py +1 -1
  35. doctr/io/image/__init__.py +1 -7
  36. doctr/io/image/base.py +1 -1
  37. doctr/io/image/pytorch.py +2 -2
  38. doctr/io/pdf.py +1 -1
  39. doctr/io/reader.py +1 -1
  40. doctr/models/_utils.py +56 -18
  41. doctr/models/builder.py +1 -1
  42. doctr/models/classification/magc_resnet/__init__.py +1 -6
  43. doctr/models/classification/magc_resnet/pytorch.py +3 -3
  44. doctr/models/classification/mobilenet/__init__.py +1 -6
  45. doctr/models/classification/mobilenet/pytorch.py +1 -1
  46. doctr/models/classification/predictor/__init__.py +1 -6
  47. doctr/models/classification/predictor/pytorch.py +2 -2
  48. doctr/models/classification/resnet/__init__.py +1 -6
  49. doctr/models/classification/resnet/pytorch.py +1 -1
  50. doctr/models/classification/textnet/__init__.py +1 -6
  51. doctr/models/classification/textnet/pytorch.py +2 -2
  52. doctr/models/classification/vgg/__init__.py +1 -6
  53. doctr/models/classification/vgg/pytorch.py +1 -1
  54. doctr/models/classification/vip/__init__.py +1 -4
  55. doctr/models/classification/vip/layers/__init__.py +1 -4
  56. doctr/models/classification/vip/layers/pytorch.py +2 -2
  57. doctr/models/classification/vip/pytorch.py +1 -1
  58. doctr/models/classification/vit/__init__.py +1 -6
  59. doctr/models/classification/vit/pytorch.py +3 -3
  60. doctr/models/classification/zoo.py +7 -12
  61. doctr/models/core.py +1 -1
  62. doctr/models/detection/_utils/__init__.py +1 -6
  63. doctr/models/detection/_utils/base.py +1 -1
  64. doctr/models/detection/_utils/pytorch.py +1 -1
  65. doctr/models/detection/core.py +2 -2
  66. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  67. doctr/models/detection/differentiable_binarization/base.py +5 -13
  68. doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
  69. doctr/models/detection/fast/__init__.py +1 -6
  70. doctr/models/detection/fast/base.py +5 -15
  71. doctr/models/detection/fast/pytorch.py +5 -5
  72. doctr/models/detection/linknet/__init__.py +1 -6
  73. doctr/models/detection/linknet/base.py +4 -13
  74. doctr/models/detection/linknet/pytorch.py +3 -3
  75. doctr/models/detection/predictor/__init__.py +1 -6
  76. doctr/models/detection/predictor/pytorch.py +2 -2
  77. doctr/models/detection/zoo.py +16 -33
  78. doctr/models/factory/hub.py +26 -34
  79. doctr/models/kie_predictor/__init__.py +1 -6
  80. doctr/models/kie_predictor/base.py +1 -1
  81. doctr/models/kie_predictor/pytorch.py +3 -7
  82. doctr/models/modules/layers/__init__.py +1 -6
  83. doctr/models/modules/layers/pytorch.py +4 -4
  84. doctr/models/modules/transformer/__init__.py +1 -6
  85. doctr/models/modules/transformer/pytorch.py +3 -3
  86. doctr/models/modules/vision_transformer/__init__.py +1 -6
  87. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  88. doctr/models/predictor/__init__.py +1 -6
  89. doctr/models/predictor/base.py +4 -9
  90. doctr/models/predictor/pytorch.py +3 -6
  91. doctr/models/preprocessor/__init__.py +1 -6
  92. doctr/models/preprocessor/pytorch.py +28 -33
  93. doctr/models/recognition/core.py +1 -1
  94. doctr/models/recognition/crnn/__init__.py +1 -6
  95. doctr/models/recognition/crnn/pytorch.py +7 -7
  96. doctr/models/recognition/master/__init__.py +1 -6
  97. doctr/models/recognition/master/base.py +1 -1
  98. doctr/models/recognition/master/pytorch.py +6 -6
  99. doctr/models/recognition/parseq/__init__.py +1 -6
  100. doctr/models/recognition/parseq/base.py +1 -1
  101. doctr/models/recognition/parseq/pytorch.py +6 -6
  102. doctr/models/recognition/predictor/__init__.py +1 -6
  103. doctr/models/recognition/predictor/_utils.py +8 -17
  104. doctr/models/recognition/predictor/pytorch.py +2 -3
  105. doctr/models/recognition/sar/__init__.py +1 -6
  106. doctr/models/recognition/sar/pytorch.py +4 -4
  107. doctr/models/recognition/utils.py +1 -1
  108. doctr/models/recognition/viptr/__init__.py +1 -4
  109. doctr/models/recognition/viptr/pytorch.py +4 -4
  110. doctr/models/recognition/vitstr/__init__.py +1 -6
  111. doctr/models/recognition/vitstr/base.py +1 -1
  112. doctr/models/recognition/vitstr/pytorch.py +4 -4
  113. doctr/models/recognition/zoo.py +14 -14
  114. doctr/models/utils/__init__.py +1 -6
  115. doctr/models/utils/pytorch.py +3 -2
  116. doctr/models/zoo.py +1 -1
  117. doctr/transforms/functional/__init__.py +1 -6
  118. doctr/transforms/functional/base.py +3 -2
  119. doctr/transforms/functional/pytorch.py +5 -5
  120. doctr/transforms/modules/__init__.py +1 -7
  121. doctr/transforms/modules/base.py +28 -94
  122. doctr/transforms/modules/pytorch.py +29 -27
  123. doctr/utils/common_types.py +1 -1
  124. doctr/utils/data.py +1 -2
  125. doctr/utils/fonts.py +1 -1
  126. doctr/utils/geometry.py +7 -11
  127. doctr/utils/metrics.py +1 -1
  128. doctr/utils/multithreading.py +1 -1
  129. doctr/utils/reconstitution.py +1 -1
  130. doctr/utils/repr.py +1 -1
  131. doctr/utils/visualization.py +2 -2
  132. doctr/version.py +1 -1
  133. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
  134. python_doctr-1.0.1.dist-info/RECORD +149 -0
  135. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
  136. doctr/datasets/datasets/tensorflow.py +0 -59
  137. doctr/datasets/generator/tensorflow.py +0 -58
  138. doctr/datasets/loader.py +0 -94
  139. doctr/io/image/tensorflow.py +0 -101
  140. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  141. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  142. doctr/models/classification/predictor/tensorflow.py +0 -60
  143. doctr/models/classification/resnet/tensorflow.py +0 -418
  144. doctr/models/classification/textnet/tensorflow.py +0 -275
  145. doctr/models/classification/vgg/tensorflow.py +0 -125
  146. doctr/models/classification/vit/tensorflow.py +0 -201
  147. doctr/models/detection/_utils/tensorflow.py +0 -34
  148. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  149. doctr/models/detection/fast/tensorflow.py +0 -427
  150. doctr/models/detection/linknet/tensorflow.py +0 -377
  151. doctr/models/detection/predictor/tensorflow.py +0 -70
  152. doctr/models/kie_predictor/tensorflow.py +0 -187
  153. doctr/models/modules/layers/tensorflow.py +0 -171
  154. doctr/models/modules/transformer/tensorflow.py +0 -235
  155. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  156. doctr/models/predictor/tensorflow.py +0 -155
  157. doctr/models/preprocessor/tensorflow.py +0 -122
  158. doctr/models/recognition/crnn/tensorflow.py +0 -317
  159. doctr/models/recognition/master/tensorflow.py +0 -320
  160. doctr/models/recognition/parseq/tensorflow.py +0 -516
  161. doctr/models/recognition/predictor/tensorflow.py +0 -79
  162. doctr/models/recognition/sar/tensorflow.py +0 -423
  163. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  164. doctr/models/utils/tensorflow.py +0 -189
  165. doctr/transforms/functional/tensorflow.py +0 -254
  166. doctr/transforms/modules/tensorflow.py +0 -562
  167. python_doctr-0.12.0.dist-info/RECORD +0 -180
  168. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
  169. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
  170. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
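Almost every removal in this list is a `tensorflow.py` module: the TensorFlow backend shipped with 0.12.0 is dropped in 1.0.1, leaving PyTorch as the only backend. As a quick post-upgrade check, something along these lines should still work (a hedged sketch using the standard doctr zoo entry points; `page.jpg` is a placeholder path, not a file from this diff):

>>> import doctr
>>> from doctr.io import DocumentFile
>>> from doctr.models import ocr_predictor
>>> print(doctr.__version__)  # expected to print 1.0.1 after the upgrade
>>> doc = DocumentFile.from_images(["page.jpg"])  # placeholder input image
>>> predictor = ocr_predictor(pretrained=True)  # downloads default detection + recognition weights
>>> result = predictor(doc)
>>> print(result.render())  # plain-text export of the OCR result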
doctr/models/recognition/parseq/tensorflow.py (deleted)
@@ -1,516 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
- import math
- from copy import deepcopy
- from itertools import permutations
- from typing import Any
-
- import numpy as np
- import tensorflow as tf
- from tensorflow.keras import Model, layers
-
- from doctr.datasets import VOCABS
- from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
-
- from ...classification import vit_s
- from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
- from .base import _PARSeq, _PARSeqPostProcessor
-
- __all__ = ["PARSeq", "parseq"]
-
- default_cfgs: dict[str, dict[str, Any]] = {
-     "parseq": {
-         "mean": (0.694, 0.695, 0.693),
-         "std": (0.299, 0.296, 0.301),
-         "input_shape": (32, 128, 3),
-         "vocab": VOCABS["french"],
-         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
-     },
- }
-
-
- class CharEmbedding(layers.Layer):
-     """Implements the character embedding module
-
-     Args:
-
-         vocab_size: size of the vocabulary
-         d_model: dimension of the model
-     """
-
-     def __init__(self, vocab_size: int, d_model: int):
-         super(CharEmbedding, self).__init__()
-         self.embedding = layers.Embedding(vocab_size, d_model)
-         self.d_model = d_model
-
-     def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
-         return math.sqrt(self.d_model) * self.embedding(x, **kwargs)
-
-
- class PARSeqDecoder(layers.Layer):
-     """Implements decoder module of the PARSeq model
-
-     Args:
-         d_model: dimension of the model
-         num_heads: number of attention heads
-         ffd: dimension of the feed forward layer
-         ffd_ratio: depth multiplier for the feed forward layer
-         dropout: dropout rate
-     """
-
-     def __init__(
-         self,
-         d_model: int,
-         num_heads: int = 12,
-         ffd: int = 2048,
-         ffd_ratio: int = 4,
-         dropout: float = 0.1,
-     ):
-         super(PARSeqDecoder, self).__init__()
-         self.attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
-         self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
-         self.position_feed_forward = PositionwiseFeedForward(
-             d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
-         )
-
-         self.query_norm = layers.LayerNormalization(epsilon=1e-5)
-         self.content_norm = layers.LayerNormalization(epsilon=1e-5)
-         self.feed_forward_norm = layers.LayerNormalization(epsilon=1e-5)
-         self.output_norm = layers.LayerNormalization(epsilon=1e-5)
-         self.attention_dropout = layers.Dropout(dropout)
-         self.cross_attention_dropout = layers.Dropout(dropout)
-         self.feed_forward_dropout = layers.Dropout(dropout)
-
-     def call(
-         self,
-         target,
-         content,
-         memory,
-         target_mask=None,
-         **kwargs: Any,
-     ):
-         query_norm = self.query_norm(target, **kwargs)
-         content_norm = self.content_norm(content, **kwargs)
-         target = target + self.attention_dropout(
-             self.attention(query_norm, content_norm, content_norm, mask=target_mask, **kwargs), **kwargs
-         )
-         target = target + self.cross_attention_dropout(
-             self.cross_attention(self.query_norm(target, **kwargs), memory, memory, **kwargs), **kwargs
-         )
-         target = target + self.feed_forward_dropout(
-             self.position_feed_forward(self.feed_forward_norm(target, **kwargs), **kwargs), **kwargs
-         )
-         return self.output_norm(target, **kwargs)
-
-
- class PARSeq(_PARSeq, Model):
-     """Implements a PARSeq architecture as described in `"Scene Text Recognition
-     with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
-     Modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
-
-     Args:
-         feature_extractor: the backbone serving as feature extractor
-         vocab: vocabulary used for encoding
-         embedding_units: number of embedding units
-         max_length: maximum word length handled by the model
-         dropout_prob: dropout probability for the decoder
-         dec_num_heads: number of attention heads in the decoder
-         dec_ff_dim: dimension of the feed forward layer in the decoder
-         dec_ffd_ratio: depth multiplier for the feed forward layer in the decoder
-         input_shape: input shape of the image
-         exportable: onnx exportable returns only logits
-         cfg: dictionary containing information about the model
-     """
-
-     _children_names: list[str] = ["feat_extractor", "postprocessor"]
-
-     def __init__(
-         self,
-         feature_extractor,
-         vocab: str,
-         embedding_units: int,
-         max_length: int = 32,  # different from paper
-         dropout_prob: float = 0.1,
-         dec_num_heads: int = 12,
-         dec_ff_dim: int = 384,  # we use it from the original implementation instead of 2048
-         dec_ffd_ratio: int = 4,
-         input_shape: tuple[int, int, int] = (32, 128, 3),
-         exportable: bool = False,
-         cfg: dict[str, Any] | None = None,
-     ) -> None:
-         super().__init__()
-         self.vocab = vocab
-         self.exportable = exportable
-         self.cfg = cfg
-         self.max_length = max_length
-         self.vocab_size = len(vocab)
-         self.rng = np.random.default_rng()
-
-         self.feat_extractor = feature_extractor
-         self.decoder = PARSeqDecoder(embedding_units, dec_num_heads, dec_ff_dim, dec_ffd_ratio, dropout_prob)
-         self.embed = CharEmbedding(self.vocab_size + 3, embedding_units)  # +3 for SOS, EOS, PAD
-         self.head = layers.Dense(self.vocab_size + 1, name="head")  # +1 for EOS
-         self.pos_queries = self.add_weight(
-             shape=(1, self.max_length + 1, embedding_units),
-             initializer="zeros",
-             trainable=True,
-             name="positions",
-         )
-         self.dropout = layers.Dropout(dropout_prob)
-
-         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
-
-     def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
-         """Load pretrained parameters onto the model
-
-         Args:
-             path_or_url: the path or URL to the model parameters (checkpoint)
-             **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
-         """
-         # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
-         # ref.: https://github.com/mindee/doctr/issues/1911
-         kwargs["skip_mismatch"] = True
-         load_pretrained_params(self, path_or_url, **kwargs)
-
-     def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
-         # Generates permutations of the target sequence.
-         # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
-         # with small modifications
-
-         max_num_chars = int(tf.reduce_max(seqlen))  # get longest sequence length in batch
-         perms = [tf.range(max_num_chars, dtype=tf.int32)]
-
-         max_perms = math.factorial(max_num_chars) // 2
-         num_gen_perms = min(3, max_perms)
-         if max_num_chars < 5:
-             # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
-             # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
-             if max_num_chars == 4:
-                 selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
-             else:
-                 selector = list(range(max_perms))
-             perm_pool_candidates = list(permutations(range(max_num_chars), max_num_chars))
-             perm_pool = tf.convert_to_tensor([perm_pool_candidates[i] for i in selector])
-             # If the forward permutation is always selected, no need to add it to the pool for sampling
-             perm_pool = perm_pool[1:]
-             final_perms = tf.stack(perms)
-             if len(perm_pool):
-                 i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False)
-                 final_perms = tf.concat([final_perms, perm_pool[i[0] : i[1]]], axis=0)
-         else:
-             perms.extend([
-                 tf.random.shuffle(tf.range(max_num_chars, dtype=tf.int32)) for _ in range(num_gen_perms - len(perms))
-             ])
-             final_perms = tf.stack(perms)
-
-         comp = tf.reverse(final_perms, axis=[-1])
-         final_perms = tf.stack([final_perms, comp])
-         final_perms = tf.transpose(final_perms, perm=[1, 0, 2])
-         final_perms = tf.reshape(final_perms, shape=(-1, max_num_chars))
-
-         sos_idx = tf.zeros([tf.shape(final_perms)[0], 1], dtype=tf.int32)
-         eos_idx = tf.fill([tf.shape(final_perms)[0], 1], max_num_chars + 1)
-         combined = tf.concat([sos_idx, final_perms + 1, eos_idx], axis=1)
-         combined = tf.cast(combined, dtype=tf.int32)
-         if tf.shape(combined)[0] > 1:
-             combined = tf.tensor_scatter_nd_update(
-                 combined, [[1, i] for i in range(1, max_num_chars + 2)], max_num_chars + 1 - tf.range(max_num_chars + 1)
-             )
-         return combined
-
-     def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
-         # Generate source and target mask for the decoder attention.
-         sz = permutation.shape[0]
-         mask = tf.ones((sz, sz), dtype=tf.float32)
-
-         for i in range(sz - 1):
-             query_idx = int(permutation[i])
-             masked_keys = permutation[i + 1 :].numpy().tolist()
-             indices = tf.constant([[query_idx, j] for j in masked_keys], dtype=tf.int32)
-             mask = tf.tensor_scatter_nd_update(mask, indices, tf.zeros(len(masked_keys), dtype=tf.float32))
-
-         source_mask = tf.identity(mask[:-1, :-1])
-         eye_indices = tf.eye(sz, dtype=tf.bool)
-         mask = tf.tensor_scatter_nd_update(
-             mask, tf.where(eye_indices), tf.zeros_like(tf.boolean_mask(mask, eye_indices))
-         )
-         target_mask = mask[1:, :-1]
-         return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)
-
-     def decode(
-         self,
-         target: tf.Tensor,
-         memory: tf.Tensor,
-         target_mask: tf.Tensor | None = None,
-         target_query: tf.Tensor | None = None,
-         **kwargs: Any,
-     ) -> tf.Tensor:
-         batch_size, sequence_length = target.shape
-         # apply positional information to the target sequence excluding the SOS token
-         null_ctx = self.embed(target[:, :1], **kwargs)
-         content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:], **kwargs)
-         content = self.dropout(tf.concat([null_ctx, content], axis=1), **kwargs)
-         if target_query is None:
-             target_query = tf.tile(self.pos_queries[:, :sequence_length], [batch_size, 1, 1])
-         target_query = self.dropout(target_query, **kwargs)
-         return self.decoder(target_query, content, memory, target_mask, **kwargs)
-
-     def decode_autoregressive(self, features: tf.Tensor, max_len: int | None = None, **kwargs) -> tf.Tensor:
-         """Generate predictions for the given features."""
-         max_length = max_len if max_len is not None else self.max_length
-         max_length = min(max_length, self.max_length) + 1
-         b = tf.shape(features)[0]
-         # Padding symbol + SOS at the beginning
-         ys = tf.fill(dims=(b, max_length), value=self.vocab_size + 2)
-         start_vector = tf.fill(dims=(b, 1), value=self.vocab_size + 1)
-         ys = tf.concat([start_vector, ys], axis=-1)
-         pos_queries = tf.tile(self.pos_queries[:, :max_length], [b, 1, 1])
-         query_mask = tf.cast(tf.linalg.band_part(tf.ones((max_length, max_length)), -1, 0), dtype=tf.bool)
-
-         pos_logits = []
-         for i in range(max_length):
-             # Decode one token at a time without providing information about the future tokens
-             tgt_out = self.decode(
-                 ys[:, : i + 1],
-                 features,
-                 query_mask[i : i + 1, : i + 1],
-                 target_query=pos_queries[:, i : i + 1],
-                 **kwargs,
-             )
-             pos_prob = self.head(tgt_out)
-             pos_logits.append(pos_prob)
-
-             if i + 1 < max_length:
-                 # update ys with the next token
-                 i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_length), indexing="ij")
-                 indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1)
-                 ys = tf.tensor_scatter_nd_update(
-                     ys, indices, tf.cast(tf.argmax(pos_prob[:, -1, :], axis=-1), dtype=tf.int32)
-                 )
-
-                 # Stop decoding if all sequences have reached the EOS token
-                 # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-                 if (
-                     not self.exportable
-                     and max_len is None
-                     and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1))
-                 ):
-                     break
-
-         logits = tf.concat(pos_logits, axis=1)  # (N, max_length, vocab_size + 1)
-
-         # One refine iteration
-         # Update query mask
-         diag_matrix = tf.eye(max_length)
-         diag_matrix = tf.cast(tf.logical_not(tf.cast(diag_matrix, dtype=tf.bool)), dtype=tf.float32)
-         query_mask = tf.cast(tf.concat([diag_matrix[1:], tf.ones((1, max_length))], axis=0), dtype=tf.bool)
-
-         sos = tf.fill((tf.shape(features)[0], 1), self.vocab_size + 1)
-         ys = tf.concat([sos, tf.cast(tf.argmax(logits[:, :-1], axis=-1), dtype=tf.int32)], axis=1)
-         # Create padding mask for refined target input maskes all behind EOS token as False
-         # (N, 1, 1, max_length)
-         mask = tf.cast(tf.equal(ys, self.vocab_size), tf.float32)
-         first_eos_indices = tf.argmax(mask, axis=1, output_type=tf.int32)
-         mask = tf.sequence_mask(first_eos_indices + 1, maxlen=ys.shape[-1], dtype=tf.float32)
-         target_pad_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.bool)
-
-         mask = tf.math.logical_and(target_pad_mask, query_mask[:, : ys.shape[1]])
-         logits = self.head(self.decode(ys, features, mask, target_query=pos_queries, **kwargs), **kwargs)
-
-         return logits  # (N, max_length, vocab_size + 1)
-
-     def call(
-         self,
-         x: tf.Tensor,
-         target: list[str] | None = None,
-         return_model_output: bool = False,
-         return_preds: bool = False,
-         **kwargs: Any,
-     ) -> dict[str, Any]:
-         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
-         # remove cls token
-         features = features[:, 1:, :]
-
-         if kwargs.get("training", False) and target is None:
-             raise ValueError("Need to provide labels during training")
-
-         if target is not None:
-             gt, seq_len = self.build_target(target)
-             seq_len = tf.cast(seq_len, tf.int32)
-             gt = gt[:, : int(tf.reduce_max(seq_len)) + 2]  # slice up to the max length of the batch + 2 (SOS + EOS)
-
-             if kwargs.get("training", False):
-                 # Generate permutations of the target sequences
-                 tgt_perms = self.generate_permutations(seq_len)
-
-                 gt_in = gt[:, :-1]  # remove EOS token from longest target sequence
-                 gt_out = gt[:, 1:]  # remove SOS token
-
-                 # Create padding mask for target input
-                 # [True, True, True, ..., False, False, False] -> False is masked
-                 padding_mask = tf.math.logical_and(
-                     tf.math.not_equal(gt_in, self.vocab_size + 2), tf.math.not_equal(gt_in, self.vocab_size)
-                 )
-                 padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]  # (N, 1, 1, seq_len)
-
-                 loss = tf.constant(0.0)
-                 loss_numel = tf.constant(0.0)
-                 n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
-                 for i, perm in enumerate(tgt_perms):
-                     _, target_mask = self.generate_permutations_attention_masks(perm)  # (seq_len, seq_len)
-                     # combine both masks to (N, 1, seq_len, seq_len)
-                     mask = tf.logical_and(padding_mask, tf.expand_dims(tf.expand_dims(target_mask, axis=0), axis=0))
-
-                     logits = self.head(self.decode(gt_in, features, mask, **kwargs), **kwargs)
-                     logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
-                     targets_flat = tf.reshape(gt_out, (-1,))
-                     mask = tf.not_equal(targets_flat, self.vocab_size + 2)
-                     loss += n * tf.reduce_mean(
-                         tf.nn.sparse_softmax_cross_entropy_with_logits(
-                             labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
-                         )
-                     )
-                     loss_numel += n
-
-                     # After the second iteration (i.e. done with canonical and reverse orderings),
-                     # remove the [EOS] tokens for the succeeding perms
-                     if i == 1:
-                         gt_out = tf.where(tf.equal(gt_out, self.vocab_size), self.vocab_size + 2, gt_out)
-                         n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
-
-                 loss /= loss_numel
-
-             else:
-                 gt = gt[:, 1:]  # remove SOS token
-                 max_len = gt.shape[1] - 1  # exclude EOS token
-                 logits = self.decode_autoregressive(features, max_len, **kwargs)
-                 logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
-                 targets_flat = tf.reshape(gt, (-1,))
-                 mask = tf.not_equal(targets_flat, self.vocab_size + 2)
-                 loss = tf.reduce_mean(
-                     tf.nn.sparse_softmax_cross_entropy_with_logits(
-                         labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
-                     )
-                 )
-         else:
-             logits = self.decode_autoregressive(features, **kwargs)
-
-         logits = _bf16_to_float32(logits)
-
-         out: dict[str, tf.Tensor] = {}
-         if self.exportable:
-             out["logits"] = logits
-             return out
-
-         if return_model_output:
-             out["out_map"] = logits
-
-         if target is None or return_preds:
-             # Post-process boxes
-             out["preds"] = self.postprocessor(logits)
-
-         if target is not None:
-             out["loss"] = loss
-
-         return out
-
-
- class PARSeqPostProcessor(_PARSeqPostProcessor):
-     """Post processor for PARSeq architecture
-
-     Args:
-         vocab: string containing the ordered sequence of supported characters
-     """
-
-     def __call__(
-         self,
-         logits: tf.Tensor,
-     ) -> list[tuple[str, float]]:
-         # compute pred with argmax for attention models
-         out_idxs = tf.math.argmax(logits, axis=2)
-         preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
-
-         # decode raw output of the model with tf_label_to_idx
-         out_idxs = tf.cast(out_idxs, dtype="int32")
-         embedding = tf.constant(self._embedding, dtype=tf.string)
-         decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
-         decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
-         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
-         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
-
-         # compute probabilties for each word up to the EOS token
-         probs = [
-             preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
-             for i, word in enumerate(word_values)
-         ]
-
-         return list(zip(word_values, probs))
-
-
- def _parseq(
-     arch: str,
-     pretrained: bool,
-     backbone_fn,
-     input_shape: tuple[int, int, int] | None = None,
-     **kwargs: Any,
- ) -> PARSeq:
-     # Patch the config
-     _cfg = deepcopy(default_cfgs[arch])
-     _cfg["input_shape"] = input_shape or _cfg["input_shape"]
-     _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
-     patch_size = kwargs.get("patch_size", (4, 8))
-
-     kwargs["vocab"] = _cfg["vocab"]
-
-     # Feature extractor
-     feat_extractor = backbone_fn(
-         # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch
-         pretrained=False,
-         input_shape=_cfg["input_shape"],
-         patch_size=patch_size,
-         include_top=False,
-     )
-
-     kwargs.pop("patch_size", None)
-     kwargs.pop("pretrained_backbone", None)
-
-     # Build the model
-     model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
-     _build_model(model)
-
-     # Load pretrained parameters
-     if pretrained:
-         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-         model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
-
-     return model
-
-
- def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
-     """PARSeq architecture from
-     `"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
-
-     >>> import tensorflow as tf
-     >>> from doctr.models import parseq
-     >>> model = parseq(pretrained=False)
-     >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
-     >>> out = model(input_tensor)
-
-     Args:
-         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-         **kwargs: keyword arguments of the PARSeq architecture
-
-     Returns:
-         text recognition architecture
-     """
-     return _parseq(
-         "parseq",
-         pretrained,
-         vit_s,
-         embedding_units=384,
-         patch_size=(4, 8),
-         **kwargs,
-     )
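The PyTorch counterpart, doctr/models/recognition/parseq/pytorch.py, is only lightly touched by this release (+6 -6 in the list above), so the docstring example from the removed TensorFlow file carries over almost verbatim. A hedged sketch of the equivalent call with the remaining backend (channels-first input, mirroring the PyTorch entry point's documented usage):

>>> import torch
>>> from doctr.models import parseq
>>> model = parseq(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 32, 128), dtype=torch.float32)  # (N, C, H, W) rather than (N, H, W, C)
>>> out = model(input_tensor)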
doctr/models/recognition/predictor/tensorflow.py (deleted)
@@ -1,79 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
- from typing import Any
-
- import numpy as np
- import tensorflow as tf
-
- from doctr.models.preprocessor import PreProcessor
- from doctr.utils.repr import NestedObject
-
- from ..core import RecognitionModel
- from ._utils import remap_preds, split_crops
-
- __all__ = ["RecognitionPredictor"]
-
-
- class RecognitionPredictor(NestedObject):
-     """Implements an object able to identify character sequences in images
-
-     Args:
-         pre_processor: transform inputs for easier batched model inference
-         model: core detection architecture
-         split_wide_crops: wether to use crop splitting for high aspect ratio crops
-     """
-
-     _children_names: list[str] = ["pre_processor", "model"]
-
-     def __init__(
-         self,
-         pre_processor: PreProcessor,
-         model: RecognitionModel,
-         split_wide_crops: bool = True,
-     ) -> None:
-         super().__init__()
-         self.pre_processor = pre_processor
-         self.model = model
-         self.split_wide_crops = split_wide_crops
-         self.critical_ar = 8  # Critical aspect ratio
-         self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
-         self.target_ar = 6  # Target aspect ratio
-
-     def __call__(
-         self,
-         crops: list[np.ndarray | tf.Tensor],
-         **kwargs: Any,
-     ) -> list[tuple[str, float]]:
-         if len(crops) == 0:
-             return []
-         # Dimension check
-         if any(crop.ndim != 3 for crop in crops):
-             raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
-
-         # Split crops that are too wide
-         remapped = False
-         if self.split_wide_crops:
-             new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.overlap_ratio)
-             if remapped:
-                 crops = new_crops
-
-         # Resize & batch them
-         processed_batches = self.pre_processor(crops)
-
-         # Forward it
-         raw = [
-             self.model(batch, return_preds=True, training=False, **kwargs)["preds"]  # type: ignore[operator]
-             for batch in processed_batches
-         ]
-
-         # Process outputs
-         out = [charseq for batch in raw for charseq in batch]
-
-         # Remap crops
-         if self.split_wide_crops and remapped:
-             out = remap_preds(out, crop_map, self.overlap_ratio)
-
-         return out
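The PyTorch version of this predictor (doctr/models/recognition/predictor/pytorch.py, +2 -3 in the list above) keeps the same crop-splitting wrapper. A hedged sketch of building it through the recognition zoo, which this diff leaves in place (the crop below is a random placeholder array):

>>> import numpy as np
>>> from doctr.models import recognition_predictor
>>> predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True)
>>> crop = np.random.randint(0, 255, (32, 128, 3), dtype=np.uint8)  # placeholder word crop, (H, W, C) uint8
>>> predictor([crop])  # returns a list of (word, confidence) tuples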