python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
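The bulk of this release is the removal of the TensorFlow backend: every `*/tensorflow.py` module and the TensorFlow-only loader are deleted, while new PyTorch-only components are added (the VIP classification backbone, the VIPTR recognizer, and a COCO-Text dataset), making 1.0.0 a PyTorch-only release. As a rough illustration (not part of the diff), the snippet below is a minimal sketch of using one of the remaining recognition factories; it assumes the factory names visible in the file list above are unchanged from 0.11.0.

import torch
from doctr.models import vitstr_small  # factory name taken from the file list; PyTorch variant assumed

model = vitstr_small(pretrained=False).eval()
# PyTorch models expect channels-first input (N, 3, H, W), unlike the deleted TF code's (N, H, W, 3)
input_tensor = torch.rand(1, 3, 32, 128)
with torch.no_grad():
    out = model(input_tensor)  # dict; "preds" holds (word, confidence) pairs when no target is given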
doctr/models/recognition/sar/tensorflow.py
@@ -1,416 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-from copy import deepcopy
-from typing import Any
-
-import tensorflow as tf
-from tensorflow.keras import Model, Sequential, layers
-
-from doctr.datasets import VOCABS
-from doctr.utils.repr import NestedObject
-
-from ...classification import resnet31
-from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
-from ..core import RecognitionModel, RecognitionPostProcessor
-
-__all__ = ["SAR", "sar_resnet31"]
-
-default_cfgs: dict[str, dict[str, Any]] = {
-    "sar_resnet31": {
-        "mean": (0.694, 0.695, 0.693),
-        "std": (0.299, 0.296, 0.301),
-        "input_shape": (32, 128, 3),
-        "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0",
-    },
-}
-
-
-class SAREncoder(layers.Layer, NestedObject):
-    """Implements encoder module of the SAR model
-
-    Args:
-        rnn_units: number of hidden rnn units
-        dropout_prob: dropout probability
-    """
-
-    def __init__(self, rnn_units: int, dropout_prob: float = 0.0) -> None:
-        super().__init__()
-        self.rnn = Sequential([
-            layers.LSTM(units=rnn_units, return_sequences=True, recurrent_dropout=dropout_prob),
-            layers.LSTM(units=rnn_units, return_sequences=False, recurrent_dropout=dropout_prob),
-        ])
-
-    def call(
-        self,
-        x: tf.Tensor,
-        **kwargs: Any,
-    ) -> tf.Tensor:
-        # (N, C)
-        return self.rnn(x, **kwargs)
-
-
-class AttentionModule(layers.Layer, NestedObject):
-    """Implements attention module of the SAR model
-
-    Args:
-        attention_units: number of hidden attention units
-
-    """
-
-    def __init__(self, attention_units: int) -> None:
-        super().__init__()
-        self.hidden_state_projector = layers.Conv2D(
-            attention_units,
-            1,
-            strides=1,
-            use_bias=False,
-            padding="same",
-            kernel_initializer="he_normal",
-        )
-        self.features_projector = layers.Conv2D(
-            attention_units,
-            3,
-            strides=1,
-            use_bias=True,
-            padding="same",
-            kernel_initializer="he_normal",
-        )
-        self.attention_projector = layers.Conv2D(
-            1,
-            1,
-            strides=1,
-            use_bias=False,
-            padding="same",
-            kernel_initializer="he_normal",
-        )
-        self.flatten = layers.Flatten()
-
-    def call(
-        self,
-        features: tf.Tensor,
-        hidden_state: tf.Tensor,
-        **kwargs: Any,
-    ) -> tf.Tensor:
-        [H, W] = features.get_shape().as_list()[1:3]
-        # shape (N, H, W, vgg_units) -> (N, H, W, attention_units)
-        features_projection = self.features_projector(features, **kwargs)
-        # shape (N, 1, 1, rnn_units) -> (N, 1, 1, attention_units)
-        hidden_state = tf.expand_dims(tf.expand_dims(hidden_state, axis=1), axis=1)
-        hidden_state_projection = self.hidden_state_projector(hidden_state, **kwargs)
-        projection = tf.math.tanh(hidden_state_projection + features_projection)
-        # shape (N, H, W, attention_units) -> (N, H, W, 1)
-        attention = self.attention_projector(projection, **kwargs)
-        # shape (N, H, W, 1) -> (N, H * W)
-        attention = self.flatten(attention)
-        attention = tf.nn.softmax(attention)
-        # shape (N, H * W) -> (N, H, W, 1)
-        attention_map = tf.reshape(attention, [-1, H, W, 1])
-        glimpse = tf.math.multiply(features, attention_map)
-        # shape (N, H * W) -> (N, C)
-        return tf.reduce_sum(glimpse, axis=[1, 2])
-
-
-class SARDecoder(layers.Layer, NestedObject):
-    """Implements decoder module of the SAR model
-
-    Args:
-        rnn_units: number of hidden units in recurrent cells
-        max_length: maximum length of a sequence
-        vocab_size: number of classes in the model alphabet
-        embedding_units: number of hidden embedding units
-        attention_units: number of hidden attention units
-        num_decoder_cells: number of LSTMCell layers to stack
-        dropout_prob: dropout probability
-
-    """
-
-    def __init__(
-        self,
-        rnn_units: int,
-        max_length: int,
-        vocab_size: int,
-        embedding_units: int,
-        attention_units: int,
-        num_decoder_cells: int = 2,
-        dropout_prob: float = 0.0,
-    ) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        self.max_length = max_length
-
-        self.embed = layers.Dense(embedding_units, use_bias=False)
-        self.embed_tgt = layers.Embedding(embedding_units, self.vocab_size + 1)
-
-        self.lstm_cells = layers.StackedRNNCells([
-            layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_cells)
-        ])
-        self.attention_module = AttentionModule(attention_units)
-        self.output_dense = layers.Dense(self.vocab_size + 1, use_bias=True)
-        self.dropout = layers.Dropout(dropout_prob)
-
-    def call(
-        self,
-        features: tf.Tensor,
-        holistic: tf.Tensor,
-        gt: tf.Tensor | None = None,
-        **kwargs: Any,
-    ) -> tf.Tensor:
-        if gt is not None:
-            gt_embedding = self.embed_tgt(gt, **kwargs)
-
-        logits_list: list[tf.Tensor] = []
-
-        for t in range(self.max_length + 1):  # 32
-            if t == 0:
-                # step to init the first states of the LSTMCell
-                states = self.lstm_cells.get_initial_state(
-                    inputs=None, batch_size=features.shape[0], dtype=features.dtype
-                )
-                prev_symbol = holistic
-            elif t == 1:
-                # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
-                # (N, vocab_size + 1) --> (N, embedding_units)
-                prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1], dtype=features.dtype)
-                prev_symbol = self.embed(prev_symbol, **kwargs)
-            else:
-                if gt is not None and kwargs.get("training", False):
-                    # (N, embedding_units) -2 because of <bos> and <eos> (same)
-                    prev_symbol = self.embed(gt_embedding[:, t - 2], **kwargs)
-                else:
-                    # -1 to start at timestep where prev_symbol was initialized
-                    index = tf.argmax(logits_list[t - 1], axis=-1)
-                    # update prev_symbol with ones at the index of the previous logit vector
-                    prev_symbol = self.embed(self.embed_tgt(index, **kwargs), **kwargs)
-
-            # (N, C), (N, C) take the last hidden state and cell state from current timestep
-            _, states = self.lstm_cells(prev_symbol, states, **kwargs)
-            # states = (hidden_state, cell_state)
-            hidden_state = states[0][0]
-            # (N, H, W, C), (N, C) --> (N, C)
-            glimpse = self.attention_module(features, hidden_state, **kwargs)
-            # (N, C), (N, C) --> (N, 2 * C)
-            logits = tf.concat([hidden_state, glimpse], axis=1)
-            logits = self.dropout(logits, **kwargs)
-            # (N, vocab_size + 1)
-            logits_list.append(self.output_dense(logits, **kwargs))
-
-        # (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1)
-        return tf.transpose(tf.stack(logits_list[1:]), (1, 0, 2))
-
-
-class SAR(Model, RecognitionModel):
-    """Implements a SAR architecture as described in `"Show, Attend and Read:A Simple and Strong Baseline for
-    Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
-
-    Args:
-        feature_extractor: the backbone serving as feature extractor
-        vocab: vocabulary used for encoding
-        rnn_units: number of hidden units in both encoder and decoder LSTM
-        embedding_units: number of embedding units
-        attention_units: number of hidden units in attention module
-        max_length: maximum word length handled by the model
-        num_decoder_cells: number of LSTMCell layers to stack
-        dropout_prob: dropout probability for the encoder and decoder
-        exportable: onnx exportable returns only logits
-        cfg: dictionary containing information about the model
-    """
-
-    _children_names: list[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
-
-    def __init__(
-        self,
-        feature_extractor,
-        vocab: str,
-        rnn_units: int = 512,
-        embedding_units: int = 512,
-        attention_units: int = 512,
-        max_length: int = 30,
-        num_decoder_cells: int = 2,
-        dropout_prob: float = 0.0,
-        exportable: bool = False,
-        cfg: dict[str, Any] | None = None,
-    ) -> None:
-        super().__init__()
-        self.vocab = vocab
-        self.exportable = exportable
-        self.cfg = cfg
-        self.max_length = max_length + 1  # Add 1 timestep for EOS after the longest word
-
-        self.feat_extractor = feature_extractor
-
-        self.encoder = SAREncoder(rnn_units, dropout_prob)
-        self.decoder = SARDecoder(
-            rnn_units,
-            self.max_length,
-            len(vocab),
-            embedding_units,
-            attention_units,
-            num_decoder_cells,
-            dropout_prob,
-        )
-
-        self.postprocessor = SARPostProcessor(vocab=vocab)
-
-    @staticmethod
-    def compute_loss(
-        model_output: tf.Tensor,
-        gt: tf.Tensor,
-        seq_len: tf.Tensor,
-    ) -> tf.Tensor:
-        """Compute categorical cross-entropy loss for the model.
-        Sequences are masked after the EOS character.
-
-        Args:
-            gt: the encoded tensor with gt labels
-            model_output: predicted logits of the model
-            seq_len: lengths of each gt word inside the batch
-
-        Returns:
-            The loss of the model on the batch
-        """
-        # Input length : number of timesteps
-        input_len = tf.shape(model_output)[1]
-        # Add one for additional <eos> token
-        seq_len = seq_len + 1
-        # One-hot gt labels
-        oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
-        # Compute loss
-        cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt, model_output)
-        # Compute mask
-        mask_values = tf.zeros_like(cce)
-        mask_2d = tf.sequence_mask(seq_len, input_len)
-        masked_loss = tf.where(mask_2d, cce, mask_values)
-        ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))
-        return tf.expand_dims(ce_loss, axis=1)
-
-    def call(
-        self,
-        x: tf.Tensor,
-        target: list[str] | None = None,
-        return_model_output: bool = False,
-        return_preds: bool = False,
-        **kwargs: Any,
-    ) -> dict[str, Any]:
-        features = self.feat_extractor(x, **kwargs)
-        # vertical max pooling --> (N, C, W)
-        pooled_features = tf.reduce_max(features, axis=1)
-        # holistic (N, C)
-        encoded = self.encoder(pooled_features, **kwargs)
-
-        if target is not None:
-            gt, seq_len = self.build_target(target)
-            seq_len = tf.cast(seq_len, tf.int32)
-
-        if kwargs.get("training", False) and target is None:
-            raise ValueError("Need to provide labels during training for teacher forcing")
-
-        decoded_features = _bf16_to_float32(
-            self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
-        )
-
-        out: dict[str, tf.Tensor] = {}
-        if self.exportable:
-            out["logits"] = decoded_features
-            return out
-
-        if return_model_output:
-            out["out_map"] = decoded_features
-
-        if target is None or return_preds:
-            # Post-process boxes
-            out["preds"] = self.postprocessor(decoded_features)
-
-        if target is not None:
-            out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
-
-        return out
-
-
-class SARPostProcessor(RecognitionPostProcessor):
-    """Post processor for SAR architectures
-
-    Args:
-        vocab: string containing the ordered sequence of supported characters
-    """
-
-    def __call__(
-        self,
-        logits: tf.Tensor,
-    ) -> list[tuple[str, float]]:
-        # compute pred with argmax for attention models
-        out_idxs = tf.math.argmax(logits, axis=2)
-        # N x L
-        probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
-        # Take the minimum confidence of the sequence
-        probs = tf.math.reduce_min(probs, axis=1)
-
-        # decode raw output of the model with tf_label_to_idx
-        out_idxs = tf.cast(out_idxs, dtype="int32")
-        embedding = tf.constant(self._embedding, dtype=tf.string)
-        decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
-        decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
-        decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
-        word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
-
-        return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
-
-
-def _sar(
-    arch: str,
-    pretrained: bool,
-    backbone_fn,
-    pretrained_backbone: bool = True,
-    input_shape: tuple[int, int, int] | None = None,
-    **kwargs: Any,
-) -> SAR:
-    pretrained_backbone = pretrained_backbone and not pretrained
-
-    # Patch the config
-    _cfg = deepcopy(default_cfgs[arch])
-    _cfg["input_shape"] = input_shape or _cfg["input_shape"]
-    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
-
-    # Feature extractor
-    feat_extractor = backbone_fn(
-        pretrained=pretrained_backbone,
-        input_shape=_cfg["input_shape"],
-        include_top=False,
-    )
-
-    kwargs["vocab"] = _cfg["vocab"]
-
-    # Build the model
-    model = SAR(feat_extractor, cfg=_cfg, **kwargs)
-    _build_model(model)
-    # Load pretrained parameters
-    if pretrained:
-        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
-
-    return model
-
-
-def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
-    """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong
-    Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
-
-    >>> import tensorflow as tf
-    >>> from doctr.models import sar_resnet31
-    >>> model = sar_resnet31(pretrained=False)
-    >>> input_tensor = tf.random.uniform(shape=[1, 64, 256, 3], maxval=1, dtype=tf.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-        **kwargs: keyword arguments of the SAR architecture
-
-    Returns:
-        text recognition architecture
-    """
-    return _sar("sar_resnet31", pretrained, resnet31, **kwargs)
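For reference, the glimpse computation implemented by the deleted AttentionModule above (project the decoder hidden state and the feature map, add them, apply tanh, softmax over all spatial positions, then take the attention-weighted sum of the features) can be written in PyTorch roughly as follows. This is a hypothetical sketch for illustration only, not the contents of doctr/models/recognition/sar/pytorch.py (file 83 in this diff).

import torch
from torch import nn

class AttentionGlimpse(nn.Module):
    # Hypothetical PyTorch translation of the deleted TF AttentionModule, channels-first layout
    def __init__(self, feat_channels: int, rnn_units: int, attention_units: int):
        super().__init__()
        self.hidden_state_projector = nn.Conv2d(rnn_units, attention_units, 1, bias=False)
        self.features_projector = nn.Conv2d(feat_channels, attention_units, 3, padding=1, bias=True)
        self.attention_projector = nn.Conv2d(attention_units, 1, 1, bias=False)

    def forward(self, features: torch.Tensor, hidden_state: torch.Tensor) -> torch.Tensor:
        # features: (N, C, H, W), hidden_state: (N, rnn_units)
        feat_proj = self.features_projector(features)
        hidden_proj = self.hidden_state_projector(hidden_state[:, :, None, None])
        attention = self.attention_projector(torch.tanh(hidden_proj + feat_proj))  # (N, 1, H, W)
        attention = torch.softmax(attention.flatten(1), dim=1).reshape_as(attention)
        # attention-weighted sum over spatial positions -> one glimpse vector per image: (N, C)
        return (features * attention).sum(dim=(2, 3))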
doctr/models/recognition/vitstr/tensorflow.py
@@ -1,278 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-from copy import deepcopy
-from typing import Any
-
-import tensorflow as tf
-from tensorflow.keras import Model, layers
-
-from doctr.datasets import VOCABS
-
-from ...classification import vit_b, vit_s
-from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
-from .base import _ViTSTR, _ViTSTRPostProcessor
-
-__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
-
-default_cfgs: dict[str, dict[str, Any]] = {
-    "vitstr_small": {
-        "mean": (0.694, 0.695, 0.693),
-        "std": (0.299, 0.296, 0.301),
-        "input_shape": (32, 128, 3),
-        "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
-    },
-    "vitstr_base": {
-        "mean": (0.694, 0.695, 0.693),
-        "std": (0.299, 0.296, 0.301),
-        "input_shape": (32, 128, 3),
-        "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0",
-    },
-}
-
-
-class ViTSTR(_ViTSTR, Model):
-    """Implements a ViTSTR architecture as described in `"Vision Transformer for Fast and
-    Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
-
-    Args:
-        feature_extractor: the backbone serving as feature extractor
-        vocab: vocabulary used for encoding
-        embedding_units: number of embedding units
-        max_length: maximum word length handled by the model
-        dropout_prob: dropout probability for the encoder and decoder
-        input_shape: input shape of the image
-        exportable: onnx exportable returns only logits
-        cfg: dictionary containing information about the model
-    """
-
-    _children_names: list[str] = ["feat_extractor", "postprocessor"]
-
-    def __init__(
-        self,
-        feature_extractor,
-        vocab: str,
-        embedding_units: int,
-        max_length: int = 32,
-        dropout_prob: float = 0.0,
-        input_shape: tuple[int, int, int] = (32, 128, 3),  # different from paper
-        exportable: bool = False,
-        cfg: dict[str, Any] | None = None,
-    ) -> None:
-        super().__init__()
-        self.vocab = vocab
-        self.exportable = exportable
-        self.cfg = cfg
-        self.max_length = max_length + 2  # +2 for SOS and EOS
-
-        self.feat_extractor = feature_extractor
-        self.head = layers.Dense(len(self.vocab) + 1, name="head")  # +1 for EOS
-
-        self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
-
-    @staticmethod
-    def compute_loss(
-        model_output: tf.Tensor,
-        gt: tf.Tensor,
-        seq_len: list[int],
-    ) -> tf.Tensor:
-        """Compute categorical cross-entropy loss for the model.
-        Sequences are masked after the EOS character.
-
-        Args:
-            model_output: predicted logits of the model
-            gt: the encoded tensor with gt labels
-            seq_len: lengths of each gt word inside the batch
-
-        Returns:
-            The loss of the model on the batch
-        """
-        # Input length : number of steps
-        input_len = tf.shape(model_output)[1]
-        # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = tf.cast(seq_len, tf.int32) + 1
-        # One-hot gt labels
-        oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
-        # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
-        # The "masked" first gt char is <sos>.
-        cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output)
-        # Compute mask
-        mask_values = tf.zeros_like(cce)
-        mask_2d = tf.sequence_mask(seq_len, input_len)
-        masked_loss = tf.where(mask_2d, cce, mask_values)
-        ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))
-
-        return tf.expand_dims(ce_loss, axis=1)
-
-    def call(
-        self,
-        x: tf.Tensor,
-        target: list[str] | None = None,
-        return_model_output: bool = False,
-        return_preds: bool = False,
-        **kwargs: Any,
-    ) -> dict[str, Any]:
-        features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
-
-        if target is not None:
-            gt, seq_len = self.build_target(target)
-            seq_len = tf.cast(seq_len, tf.int32)
-
-        if kwargs.get("training", False) and target is None:
-            raise ValueError("Need to provide labels during training")
-
-        features = features[:, : self.max_length]  # (batch_size, max_length, d_model)
-        B, N, E = features.shape
-        features = tf.reshape(features, (B * N, E))
-        logits = tf.reshape(
-            self.head(features, **kwargs), (B, N, len(self.vocab) + 1)
-        )  # (batch_size, max_length, vocab + 1)
-        decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
-
-        out: dict[str, tf.Tensor] = {}
-        if self.exportable:
-            out["logits"] = decoded_features
-            return out
-
-        if return_model_output:
-            out["out_map"] = decoded_features
-
-        if target is None or return_preds:
-            # Post-process boxes
-            out["preds"] = self.postprocessor(decoded_features)
-
-        if target is not None:
-            out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
-
-        return out
-
-
-class ViTSTRPostProcessor(_ViTSTRPostProcessor):
-    """Post processor for ViTSTR architecture
-
-    Args:
-        vocab: string containing the ordered sequence of supported characters
-    """
-
-    def __call__(
-        self,
-        logits: tf.Tensor,
-    ) -> list[tuple[str, float]]:
-        # compute pred with argmax for attention models
-        out_idxs = tf.math.argmax(logits, axis=2)
-        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
-
-        # decode raw output of the model with tf_label_to_idx
-        out_idxs = tf.cast(out_idxs, dtype="int32")
-        embedding = tf.constant(self._embedding, dtype=tf.string)
-        decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
-        decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
-        decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
-        word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
-
-        # compute probabilties for each word up to the EOS token
-        probs = [
-            preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
-            for i, word in enumerate(word_values)
-        ]
-
-        return list(zip(word_values, probs))
-
-
-def _vitstr(
-    arch: str,
-    pretrained: bool,
-    backbone_fn,
-    input_shape: tuple[int, int, int] | None = None,
-    **kwargs: Any,
-) -> ViTSTR:
-    # Patch the config
-    _cfg = deepcopy(default_cfgs[arch])
-    _cfg["input_shape"] = input_shape or _cfg["input_shape"]
-    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
-    patch_size = kwargs.get("patch_size", (4, 8))
-
-    kwargs["vocab"] = _cfg["vocab"]
-
-    # Feature extractor
-    feat_extractor = backbone_fn(
-        # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch
-        pretrained=False,
-        input_shape=_cfg["input_shape"],
-        patch_size=patch_size,
-        include_top=False,
-    )
-
-    kwargs.pop("patch_size", None)
-    kwargs.pop("pretrained_backbone", None)
-
-    # Build the model
-    model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
-    _build_model(model)
-
-    # Load pretrained parameters
-    if pretrained:
-        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
-
-    return model
-
-
-def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
-    """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
-    <https://arxiv.org/pdf/2105.08582.pdf>`_.
-
-    >>> import tensorflow as tf
-    >>> from doctr.models import vitstr_small
-    >>> model = vitstr_small(pretrained=False)
-    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-        **kwargs: keyword arguments of the ViTSTR architecture
-
-    Returns:
-        text recognition architecture
-    """
-    return _vitstr(
-        "vitstr_small",
-        pretrained,
-        vit_s,
-        embedding_units=384,
-        patch_size=(4, 8),
-        **kwargs,
-    )
-
-
-def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
-    """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
-    <https://arxiv.org/pdf/2105.08582.pdf>`_.
-
-    >>> import tensorflow as tf
-    >>> from doctr.models import vitstr_base
-    >>> model = vitstr_base(pretrained=False)
-    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-        **kwargs: keyword arguments of the ViTSTR architecture
-
-    Returns:
-        text recognition architecture
-    """
-    return _vitstr(
-        "vitstr_base",
-        pretrained,
-        vit_b,
-        embedding_units=768,
-        patch_size=(4, 8),
-        **kwargs,
-    )
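The deleted compute_loss above encodes two details worth restating: the targets are shifted by one step (the first, masked ground-truth character is <sos>, so the model never learns to simply copy gt[t-1]), and the per-step cross-entropy is masked beyond each word's <eos>. A hypothetical PyTorch rendering of that same masking is sketched below, for clarification only; none of these names come from the 1.0.0 sources.

import torch
import torch.nn.functional as F

def masked_shifted_ce(model_output: torch.Tensor, gt: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
    # model_output: (N, T, vocab_size + 1) logits; gt: (N, T + 1) token ids starting with <sos>;
    # seq_len: (N,) word lengths, before the extra <eos> step
    input_len = model_output.shape[1]
    seq_len = seq_len + 1  # one extra step for the <eos> token
    # shift gt by one so prediction at step t is compared with gt[t + 1]
    cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")  # (N, T)
    # mask out every step past <eos>
    mask = torch.arange(input_len, device=gt.device)[None, :] < seq_len[:, None]  # (N, T)
    return (cce * mask).sum(dim=1) / seq_len.to(model_output.dtype)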