python-doctr 0.12.0-py3-none-any.whl → 1.0.0-py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (116) — all TensorFlow backend modules are removed in 1.0.0 (see the note after the list)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +0 -5
  3. doctr/datasets/datasets/__init__.py +1 -6
  4. doctr/datasets/datasets/pytorch.py +2 -2
  5. doctr/datasets/generator/__init__.py +1 -6
  6. doctr/datasets/vocabs.py +0 -2
  7. doctr/file_utils.py +2 -101
  8. doctr/io/image/__init__.py +1 -7
  9. doctr/io/image/pytorch.py +1 -1
  10. doctr/models/_utils.py +3 -3
  11. doctr/models/classification/magc_resnet/__init__.py +1 -6
  12. doctr/models/classification/magc_resnet/pytorch.py +2 -2
  13. doctr/models/classification/mobilenet/__init__.py +1 -6
  14. doctr/models/classification/predictor/__init__.py +1 -6
  15. doctr/models/classification/predictor/pytorch.py +1 -1
  16. doctr/models/classification/resnet/__init__.py +1 -6
  17. doctr/models/classification/textnet/__init__.py +1 -6
  18. doctr/models/classification/textnet/pytorch.py +1 -1
  19. doctr/models/classification/vgg/__init__.py +1 -6
  20. doctr/models/classification/vip/__init__.py +1 -4
  21. doctr/models/classification/vip/layers/__init__.py +1 -4
  22. doctr/models/classification/vip/layers/pytorch.py +1 -1
  23. doctr/models/classification/vit/__init__.py +1 -6
  24. doctr/models/classification/vit/pytorch.py +2 -2
  25. doctr/models/classification/zoo.py +6 -11
  26. doctr/models/detection/_utils/__init__.py +1 -6
  27. doctr/models/detection/core.py +1 -1
  28. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  29. doctr/models/detection/differentiable_binarization/base.py +4 -12
  30. doctr/models/detection/differentiable_binarization/pytorch.py +3 -3
  31. doctr/models/detection/fast/__init__.py +1 -6
  32. doctr/models/detection/fast/base.py +4 -14
  33. doctr/models/detection/fast/pytorch.py +4 -4
  34. doctr/models/detection/linknet/__init__.py +1 -6
  35. doctr/models/detection/linknet/base.py +3 -12
  36. doctr/models/detection/linknet/pytorch.py +2 -2
  37. doctr/models/detection/predictor/__init__.py +1 -6
  38. doctr/models/detection/predictor/pytorch.py +1 -1
  39. doctr/models/detection/zoo.py +15 -32
  40. doctr/models/factory/hub.py +8 -21
  41. doctr/models/kie_predictor/__init__.py +1 -6
  42. doctr/models/kie_predictor/pytorch.py +2 -6
  43. doctr/models/modules/layers/__init__.py +1 -6
  44. doctr/models/modules/layers/pytorch.py +3 -3
  45. doctr/models/modules/transformer/__init__.py +1 -6
  46. doctr/models/modules/transformer/pytorch.py +2 -2
  47. doctr/models/modules/vision_transformer/__init__.py +1 -6
  48. doctr/models/predictor/__init__.py +1 -6
  49. doctr/models/predictor/base.py +3 -8
  50. doctr/models/predictor/pytorch.py +2 -5
  51. doctr/models/preprocessor/__init__.py +1 -6
  52. doctr/models/preprocessor/pytorch.py +27 -32
  53. doctr/models/recognition/crnn/__init__.py +1 -6
  54. doctr/models/recognition/crnn/pytorch.py +6 -6
  55. doctr/models/recognition/master/__init__.py +1 -6
  56. doctr/models/recognition/master/pytorch.py +5 -5
  57. doctr/models/recognition/parseq/__init__.py +1 -6
  58. doctr/models/recognition/parseq/pytorch.py +5 -5
  59. doctr/models/recognition/predictor/__init__.py +1 -6
  60. doctr/models/recognition/predictor/_utils.py +7 -16
  61. doctr/models/recognition/predictor/pytorch.py +1 -2
  62. doctr/models/recognition/sar/__init__.py +1 -6
  63. doctr/models/recognition/sar/pytorch.py +3 -3
  64. doctr/models/recognition/viptr/__init__.py +1 -4
  65. doctr/models/recognition/viptr/pytorch.py +3 -3
  66. doctr/models/recognition/vitstr/__init__.py +1 -6
  67. doctr/models/recognition/vitstr/pytorch.py +3 -3
  68. doctr/models/recognition/zoo.py +13 -13
  69. doctr/models/utils/__init__.py +1 -6
  70. doctr/models/utils/pytorch.py +1 -1
  71. doctr/transforms/functional/__init__.py +1 -6
  72. doctr/transforms/functional/pytorch.py +4 -4
  73. doctr/transforms/modules/__init__.py +1 -7
  74. doctr/transforms/modules/base.py +26 -92
  75. doctr/transforms/modules/pytorch.py +28 -26
  76. doctr/utils/geometry.py +6 -10
  77. doctr/utils/visualization.py +1 -1
  78. doctr/version.py +1 -1
  79. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +18 -75
  80. python_doctr-1.0.0.dist-info/RECORD +149 -0
  81. doctr/datasets/datasets/tensorflow.py +0 -59
  82. doctr/datasets/generator/tensorflow.py +0 -58
  83. doctr/datasets/loader.py +0 -94
  84. doctr/io/image/tensorflow.py +0 -101
  85. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  86. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  87. doctr/models/classification/predictor/tensorflow.py +0 -60
  88. doctr/models/classification/resnet/tensorflow.py +0 -418
  89. doctr/models/classification/textnet/tensorflow.py +0 -275
  90. doctr/models/classification/vgg/tensorflow.py +0 -125
  91. doctr/models/classification/vit/tensorflow.py +0 -201
  92. doctr/models/detection/_utils/tensorflow.py +0 -34
  93. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  94. doctr/models/detection/fast/tensorflow.py +0 -427
  95. doctr/models/detection/linknet/tensorflow.py +0 -377
  96. doctr/models/detection/predictor/tensorflow.py +0 -70
  97. doctr/models/kie_predictor/tensorflow.py +0 -187
  98. doctr/models/modules/layers/tensorflow.py +0 -171
  99. doctr/models/modules/transformer/tensorflow.py +0 -235
  100. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  101. doctr/models/predictor/tensorflow.py +0 -155
  102. doctr/models/preprocessor/tensorflow.py +0 -122
  103. doctr/models/recognition/crnn/tensorflow.py +0 -317
  104. doctr/models/recognition/master/tensorflow.py +0 -320
  105. doctr/models/recognition/parseq/tensorflow.py +0 -516
  106. doctr/models/recognition/predictor/tensorflow.py +0 -79
  107. doctr/models/recognition/sar/tensorflow.py +0 -423
  108. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  109. doctr/models/utils/tensorflow.py +0 -189
  110. doctr/transforms/functional/tensorflow.py +0 -254
  111. doctr/transforms/modules/tensorflow.py +0 -562
  112. python_doctr-0.12.0.dist-info/RECORD +0 -180
  113. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +0 -0
  114. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/licenses/LICENSE +0 -0
  115. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  116. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
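Taken together, the list above removes the TensorFlow backend (items 81-111, all `tensorflow.py` modules plus the TF data loader) while keeping the PyTorch counterparts, so 1.0.0 ships PyTorch-only. A minimal usage sketch, assuming the public high-level PyTorch API is otherwise unchanged from 0.12.0; the image path is a placeholder:

# Minimal sketch, assuming doctr's high-level PyTorch API is unchanged in 1.0.0.
# "page.jpg" is a placeholder input image, not part of the package.
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

doc = DocumentFile.from_images(["page.jpg"])  # load a single page as an image
model = ocr_predictor(pretrained=True)        # detection + recognition predictor (PyTorch backend)
result = model(doc)                           # run OCR on the page
print(result.render())                        # plain-text rendering of the OCR output

The two hunks below show the full content of two of the deleted TensorFlow modules.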
doctr/models/recognition/sar/tensorflow.py (deleted)
@@ -1,423 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
- from copy import deepcopy
- from typing import Any
-
- import tensorflow as tf
- from tensorflow.keras import Model, Sequential, layers
-
- from doctr.datasets import VOCABS
- from doctr.utils.repr import NestedObject
-
- from ...classification import resnet31
- from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
- from ..core import RecognitionModel, RecognitionPostProcessor
-
- __all__ = ["SAR", "sar_resnet31"]
-
- default_cfgs: dict[str, dict[str, Any]] = {
-     "sar_resnet31": {
-         "mean": (0.694, 0.695, 0.693),
-         "std": (0.299, 0.296, 0.301),
-         "input_shape": (32, 128, 3),
-         "vocab": VOCABS["french"],
-         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0",
-     },
- }
-
-
- class SAREncoder(layers.Layer, NestedObject):
-     """Implements encoder module of the SAR model
-
-     Args:
-         rnn_units: number of hidden rnn units
-         dropout_prob: dropout probability
-     """
-
-     def __init__(self, rnn_units: int, dropout_prob: float = 0.0) -> None:
-         super().__init__()
-         self.rnn = Sequential([
-             layers.LSTM(units=rnn_units, return_sequences=True, recurrent_dropout=dropout_prob),
-             layers.LSTM(units=rnn_units, return_sequences=False, recurrent_dropout=dropout_prob),
-         ])
-
-     def call(
-         self,
-         x: tf.Tensor,
-         **kwargs: Any,
-     ) -> tf.Tensor:
-         # (N, C)
-         return self.rnn(x, **kwargs)
-
-
- class AttentionModule(layers.Layer, NestedObject):
-     """Implements attention module of the SAR model
-
-     Args:
-         attention_units: number of hidden attention units
-
-     """
-
-     def __init__(self, attention_units: int) -> None:
-         super().__init__()
-         self.hidden_state_projector = layers.Conv2D(
-             attention_units,
-             1,
-             strides=1,
-             use_bias=False,
-             padding="same",
-             kernel_initializer="he_normal",
-         )
-         self.features_projector = layers.Conv2D(
-             attention_units,
-             3,
-             strides=1,
-             use_bias=True,
-             padding="same",
-             kernel_initializer="he_normal",
-         )
-         self.attention_projector = layers.Conv2D(
-             1,
-             1,
-             strides=1,
-             use_bias=False,
-             padding="same",
-             kernel_initializer="he_normal",
-         )
-         self.flatten = layers.Flatten()
-
-     def call(
-         self,
-         features: tf.Tensor,
-         hidden_state: tf.Tensor,
-         **kwargs: Any,
-     ) -> tf.Tensor:
-         [H, W] = features.get_shape().as_list()[1:3]
-         # shape (N, H, W, vgg_units) -> (N, H, W, attention_units)
-         features_projection = self.features_projector(features, **kwargs)
-         # shape (N, 1, 1, rnn_units) -> (N, 1, 1, attention_units)
-         hidden_state = tf.expand_dims(tf.expand_dims(hidden_state, axis=1), axis=1)
-         hidden_state_projection = self.hidden_state_projector(hidden_state, **kwargs)
-         projection = tf.math.tanh(hidden_state_projection + features_projection)
-         # shape (N, H, W, attention_units) -> (N, H, W, 1)
-         attention = self.attention_projector(projection, **kwargs)
-         # shape (N, H, W, 1) -> (N, H * W)
-         attention = self.flatten(attention)
-         attention = tf.nn.softmax(attention)
-         # shape (N, H * W) -> (N, H, W, 1)
-         attention_map = tf.reshape(attention, [-1, H, W, 1])
-         glimpse = tf.math.multiply(features, attention_map)
-         # shape (N, H * W) -> (N, C)
-         return tf.reduce_sum(glimpse, axis=[1, 2])
-
-
- class SARDecoder(layers.Layer, NestedObject):
-     """Implements decoder module of the SAR model
-
-     Args:
-         rnn_units: number of hidden units in recurrent cells
-         max_length: maximum length of a sequence
-         vocab_size: number of classes in the model alphabet
-         embedding_units: number of hidden embedding units
-         attention_units: number of hidden attention units
-         num_decoder_cells: number of LSTMCell layers to stack
-         dropout_prob: dropout probability
-
-     """
-
-     def __init__(
-         self,
-         rnn_units: int,
-         max_length: int,
-         vocab_size: int,
-         embedding_units: int,
-         attention_units: int,
-         num_decoder_cells: int = 2,
-         dropout_prob: float = 0.0,
-     ) -> None:
-         super().__init__()
-         self.vocab_size = vocab_size
-         self.max_length = max_length
-
-         self.embed = layers.Dense(embedding_units, use_bias=False)
-         self.embed_tgt = layers.Embedding(embedding_units, self.vocab_size + 1)
-
-         self.lstm_cells = layers.StackedRNNCells([
-             layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_cells)
-         ])
-         self.attention_module = AttentionModule(attention_units)
-         self.output_dense = layers.Dense(self.vocab_size + 1, use_bias=True)
-         self.dropout = layers.Dropout(dropout_prob)
-
-     def call(
-         self,
-         features: tf.Tensor,
-         holistic: tf.Tensor,
-         gt: tf.Tensor | None = None,
-         **kwargs: Any,
-     ) -> tf.Tensor:
-         if gt is not None:
-             gt_embedding = self.embed_tgt(gt, **kwargs)
-
-         logits_list: list[tf.Tensor] = []
-
-         for t in range(self.max_length + 1):  # 32
-             if t == 0:
-                 # step to init the first states of the LSTMCell
-                 states = self.lstm_cells.get_initial_state(
-                     inputs=None, batch_size=features.shape[0], dtype=features.dtype
-                 )
-                 prev_symbol = holistic
-             elif t == 1:
-                 # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
-                 # (N, vocab_size + 1) --> (N, embedding_units)
-                 prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1], dtype=features.dtype)
-                 prev_symbol = self.embed(prev_symbol, **kwargs)
-             else:
-                 if gt is not None and kwargs.get("training", False):
-                     # (N, embedding_units) -2 because of <bos> and <eos> (same)
-                     prev_symbol = self.embed(gt_embedding[:, t - 2], **kwargs)
-                 else:
-                     # -1 to start at timestep where prev_symbol was initialized
-                     index = tf.argmax(logits_list[t - 1], axis=-1)
-                     # update prev_symbol with ones at the index of the previous logit vector
-                     prev_symbol = self.embed(self.embed_tgt(index, **kwargs), **kwargs)
-
-             # (N, C), (N, C) take the last hidden state and cell state from current timestep
-             _, states = self.lstm_cells(prev_symbol, states, **kwargs)
-             # states = (hidden_state, cell_state)
-             hidden_state = states[0][0]
-             # (N, H, W, C), (N, C) --> (N, C)
-             glimpse = self.attention_module(features, hidden_state, **kwargs)
-             # (N, C), (N, C) --> (N, 2 * C)
-             logits = tf.concat([hidden_state, glimpse], axis=1)
-             logits = self.dropout(logits, **kwargs)
-             # (N, vocab_size + 1)
-             logits_list.append(self.output_dense(logits, **kwargs))
-
-         # (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1)
-         return tf.transpose(tf.stack(logits_list[1:]), (1, 0, 2))
-
-
- class SAR(Model, RecognitionModel):
-     """Implements a SAR architecture as described in `"Show, Attend and Read:A Simple and Strong Baseline for
-     Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
-
-     Args:
-         feature_extractor: the backbone serving as feature extractor
-         vocab: vocabulary used for encoding
-         rnn_units: number of hidden units in both encoder and decoder LSTM
-         embedding_units: number of embedding units
-         attention_units: number of hidden units in attention module
-         max_length: maximum word length handled by the model
-         num_decoder_cells: number of LSTMCell layers to stack
-         dropout_prob: dropout probability for the encoder and decoder
-         exportable: onnx exportable returns only logits
-         cfg: dictionary containing information about the model
-     """
-
-     _children_names: list[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
-
-     def __init__(
-         self,
-         feature_extractor,
-         vocab: str,
-         rnn_units: int = 512,
-         embedding_units: int = 512,
-         attention_units: int = 512,
-         max_length: int = 30,
-         num_decoder_cells: int = 2,
-         dropout_prob: float = 0.0,
-         exportable: bool = False,
-         cfg: dict[str, Any] | None = None,
-     ) -> None:
-         super().__init__()
-         self.vocab = vocab
-         self.exportable = exportable
-         self.cfg = cfg
-         self.max_length = max_length + 1  # Add 1 timestep for EOS after the longest word
-
-         self.feat_extractor = feature_extractor
-
-         self.encoder = SAREncoder(rnn_units, dropout_prob)
-         self.decoder = SARDecoder(
-             rnn_units,
-             self.max_length,
-             len(vocab),
-             embedding_units,
-             attention_units,
-             num_decoder_cells,
-             dropout_prob,
-         )
-
-         self.postprocessor = SARPostProcessor(vocab=vocab)
-
-     def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
-         """Load pretrained parameters onto the model
-
-         Args:
-             path_or_url: the path or URL to the model parameters (checkpoint)
-             **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
-         """
-         load_pretrained_params(self, path_or_url, **kwargs)
-
-     @staticmethod
-     def compute_loss(
-         model_output: tf.Tensor,
-         gt: tf.Tensor,
-         seq_len: tf.Tensor,
-     ) -> tf.Tensor:
-         """Compute categorical cross-entropy loss for the model.
-         Sequences are masked after the EOS character.
-
-         Args:
-             gt: the encoded tensor with gt labels
-             model_output: predicted logits of the model
-             seq_len: lengths of each gt word inside the batch
-
-         Returns:
-             The loss of the model on the batch
-         """
-         # Input length : number of timesteps
-         input_len = tf.shape(model_output)[1]
-         # Add one for additional <eos> token
-         seq_len = seq_len + 1
-         # One-hot gt labels
-         oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
-         # Compute loss
-         cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt, model_output)
-         # Compute mask
-         mask_values = tf.zeros_like(cce)
-         mask_2d = tf.sequence_mask(seq_len, input_len)
-         masked_loss = tf.where(mask_2d, cce, mask_values)
-         ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))
-         return tf.expand_dims(ce_loss, axis=1)
-
-     def call(
-         self,
-         x: tf.Tensor,
-         target: list[str] | None = None,
-         return_model_output: bool = False,
-         return_preds: bool = False,
-         **kwargs: Any,
-     ) -> dict[str, Any]:
-         features = self.feat_extractor(x, **kwargs)
-         # vertical max pooling --> (N, C, W)
-         pooled_features = tf.reduce_max(features, axis=1)
-         # holistic (N, C)
-         encoded = self.encoder(pooled_features, **kwargs)
-
-         if target is not None:
-             gt, seq_len = self.build_target(target)
-             seq_len = tf.cast(seq_len, tf.int32)
-
-         if kwargs.get("training", False) and target is None:
-             raise ValueError("Need to provide labels during training for teacher forcing")
-
-         decoded_features = _bf16_to_float32(
-             self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
-         )
-
-         out: dict[str, tf.Tensor] = {}
-         if self.exportable:
-             out["logits"] = decoded_features
-             return out
-
-         if return_model_output:
-             out["out_map"] = decoded_features
-
-         if target is None or return_preds:
-             # Post-process boxes
-             out["preds"] = self.postprocessor(decoded_features)
-
-         if target is not None:
-             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
-
-         return out
-
-
- class SARPostProcessor(RecognitionPostProcessor):
-     """Post processor for SAR architectures
-
-     Args:
-         vocab: string containing the ordered sequence of supported characters
-     """
-
-     def __call__(
-         self,
-         logits: tf.Tensor,
-     ) -> list[tuple[str, float]]:
-         # compute pred with argmax for attention models
-         out_idxs = tf.math.argmax(logits, axis=2)
-         # N x L
-         probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
-         # Take the minimum confidence of the sequence
-         probs = tf.math.reduce_min(probs, axis=1)
-
-         # decode raw output of the model with tf_label_to_idx
-         out_idxs = tf.cast(out_idxs, dtype="int32")
-         embedding = tf.constant(self._embedding, dtype=tf.string)
-         decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
-         decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
-         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
-         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
-
-         return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
-
-
- def _sar(
-     arch: str,
-     pretrained: bool,
-     backbone_fn,
-     pretrained_backbone: bool = True,
-     input_shape: tuple[int, int, int] | None = None,
-     **kwargs: Any,
- ) -> SAR:
-     pretrained_backbone = pretrained_backbone and not pretrained
-
-     # Patch the config
-     _cfg = deepcopy(default_cfgs[arch])
-     _cfg["input_shape"] = input_shape or _cfg["input_shape"]
-     _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
-
-     # Feature extractor
-     feat_extractor = backbone_fn(
-         pretrained=pretrained_backbone,
-         input_shape=_cfg["input_shape"],
-         include_top=False,
-     )
-
-     kwargs["vocab"] = _cfg["vocab"]
-
-     # Build the model
-     model = SAR(feat_extractor, cfg=_cfg, **kwargs)
-     _build_model(model)
-     # Load pretrained parameters
-     if pretrained:
-         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-         model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
-
-     return model
-
-
- def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
-     """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong
-     Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
-
-     >>> import tensorflow as tf
-     >>> from doctr.models import sar_resnet31
-     >>> model = sar_resnet31(pretrained=False)
-     >>> input_tensor = tf.random.uniform(shape=[1, 64, 256, 3], maxval=1, dtype=tf.float32)
-     >>> out = model(input_tensor)
-
-     Args:
-         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-         **kwargs: keyword arguments of the SAR architecture
-
-     Returns:
-         text recognition architecture
-     """
-     return _sar("sar_resnet31", pretrained, resnet31, **kwargs)
doctr/models/recognition/vitstr/tensorflow.py (deleted)
@@ -1,285 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
- from copy import deepcopy
- from typing import Any
-
- import tensorflow as tf
- from tensorflow.keras import Model, layers
-
- from doctr.datasets import VOCABS
-
- from ...classification import vit_b, vit_s
- from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
- from .base import _ViTSTR, _ViTSTRPostProcessor
-
- __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
-
- default_cfgs: dict[str, dict[str, Any]] = {
-     "vitstr_small": {
-         "mean": (0.694, 0.695, 0.693),
-         "std": (0.299, 0.296, 0.301),
-         "input_shape": (32, 128, 3),
-         "vocab": VOCABS["french"],
-         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
-     },
-     "vitstr_base": {
-         "mean": (0.694, 0.695, 0.693),
-         "std": (0.299, 0.296, 0.301),
-         "input_shape": (32, 128, 3),
-         "vocab": VOCABS["french"],
-         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0",
-     },
- }
-
-
- class ViTSTR(_ViTSTR, Model):
-     """Implements a ViTSTR architecture as described in `"Vision Transformer for Fast and
-     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
-
-     Args:
-         feature_extractor: the backbone serving as feature extractor
-         vocab: vocabulary used for encoding
-         embedding_units: number of embedding units
-         max_length: maximum word length handled by the model
-         dropout_prob: dropout probability for the encoder and decoder
-         input_shape: input shape of the image
-         exportable: onnx exportable returns only logits
-         cfg: dictionary containing information about the model
-     """
-
-     _children_names: list[str] = ["feat_extractor", "postprocessor"]
-
-     def __init__(
-         self,
-         feature_extractor,
-         vocab: str,
-         embedding_units: int,
-         max_length: int = 32,
-         dropout_prob: float = 0.0,
-         input_shape: tuple[int, int, int] = (32, 128, 3),  # different from paper
-         exportable: bool = False,
-         cfg: dict[str, Any] | None = None,
-     ) -> None:
-         super().__init__()
-         self.vocab = vocab
-         self.exportable = exportable
-         self.cfg = cfg
-         self.max_length = max_length + 2  # +2 for SOS and EOS
-
-         self.feat_extractor = feature_extractor
-         self.head = layers.Dense(len(self.vocab) + 1, name="head")  # +1 for EOS
-
-         self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
-
-     def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
-         """Load pretrained parameters onto the model
-
-         Args:
-             path_or_url: the path or URL to the model parameters (checkpoint)
-             **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
-         """
-         load_pretrained_params(self, path_or_url, **kwargs)
-
-     @staticmethod
-     def compute_loss(
-         model_output: tf.Tensor,
-         gt: tf.Tensor,
-         seq_len: list[int],
-     ) -> tf.Tensor:
-         """Compute categorical cross-entropy loss for the model.
-         Sequences are masked after the EOS character.
-
-         Args:
-             model_output: predicted logits of the model
-             gt: the encoded tensor with gt labels
-             seq_len: lengths of each gt word inside the batch
-
-         Returns:
-             The loss of the model on the batch
-         """
-         # Input length : number of steps
-         input_len = tf.shape(model_output)[1]
-         # Add one for additional <eos> token (sos disappear in shift!)
-         seq_len = tf.cast(seq_len, tf.int32) + 1
-         # One-hot gt labels
-         oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
-         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
-         # The "masked" first gt char is <sos>.
-         cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output)
-         # Compute mask
-         mask_values = tf.zeros_like(cce)
-         mask_2d = tf.sequence_mask(seq_len, input_len)
-         masked_loss = tf.where(mask_2d, cce, mask_values)
-         ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))
-
-         return tf.expand_dims(ce_loss, axis=1)
-
-     def call(
-         self,
-         x: tf.Tensor,
-         target: list[str] | None = None,
-         return_model_output: bool = False,
-         return_preds: bool = False,
-         **kwargs: Any,
-     ) -> dict[str, Any]:
-         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
-
-         if target is not None:
-             gt, seq_len = self.build_target(target)
-             seq_len = tf.cast(seq_len, tf.int32)
-
-         if kwargs.get("training", False) and target is None:
-             raise ValueError("Need to provide labels during training")
-
-         features = features[:, : self.max_length]  # (batch_size, max_length, d_model)
-         B, N, E = features.shape
-         features = tf.reshape(features, (B * N, E))
-         logits = tf.reshape(
-             self.head(features, **kwargs), (B, N, len(self.vocab) + 1)
-         )  # (batch_size, max_length, vocab + 1)
-         decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
-
-         out: dict[str, tf.Tensor] = {}
-         if self.exportable:
-             out["logits"] = decoded_features
-             return out
-
-         if return_model_output:
-             out["out_map"] = decoded_features
-
-         if target is None or return_preds:
-             # Post-process boxes
-             out["preds"] = self.postprocessor(decoded_features)
-
-         if target is not None:
-             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
-
-         return out
-
-
- class ViTSTRPostProcessor(_ViTSTRPostProcessor):
-     """Post processor for ViTSTR architecture
-
-     Args:
-         vocab: string containing the ordered sequence of supported characters
-     """
-
-     def __call__(
-         self,
-         logits: tf.Tensor,
-     ) -> list[tuple[str, float]]:
-         # compute pred with argmax for attention models
-         out_idxs = tf.math.argmax(logits, axis=2)
-         preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
-
-         # decode raw output of the model with tf_label_to_idx
-         out_idxs = tf.cast(out_idxs, dtype="int32")
-         embedding = tf.constant(self._embedding, dtype=tf.string)
-         decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
-         decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
-         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
-         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
-
-         # compute probabilties for each word up to the EOS token
-         probs = [
-             preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
-             for i, word in enumerate(word_values)
-         ]
-
-         return list(zip(word_values, probs))
-
-
- def _vitstr(
-     arch: str,
-     pretrained: bool,
-     backbone_fn,
-     input_shape: tuple[int, int, int] | None = None,
-     **kwargs: Any,
- ) -> ViTSTR:
-     # Patch the config
-     _cfg = deepcopy(default_cfgs[arch])
-     _cfg["input_shape"] = input_shape or _cfg["input_shape"]
-     _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
-     patch_size = kwargs.get("patch_size", (4, 8))
-
-     kwargs["vocab"] = _cfg["vocab"]
-
-     # Feature extractor
-     feat_extractor = backbone_fn(
-         # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch
-         pretrained=False,
-         input_shape=_cfg["input_shape"],
-         patch_size=patch_size,
-         include_top=False,
-     )
-
-     kwargs.pop("patch_size", None)
-     kwargs.pop("pretrained_backbone", None)
-
-     # Build the model
-     model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
-     _build_model(model)
-
-     # Load pretrained parameters
-     if pretrained:
-         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-         model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
-
-     return model
-
-
- def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
-     """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
-     <https://arxiv.org/pdf/2105.08582.pdf>`_.
-
-     >>> import tensorflow as tf
-     >>> from doctr.models import vitstr_small
-     >>> model = vitstr_small(pretrained=False)
-     >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
-     >>> out = model(input_tensor)
-
-     Args:
-         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-         **kwargs: keyword arguments of the ViTSTR architecture
-
-     Returns:
-         text recognition architecture
-     """
-     return _vitstr(
-         "vitstr_small",
-         pretrained,
-         vit_s,
-         embedding_units=384,
-         patch_size=(4, 8),
-         **kwargs,
-     )
-
-
- def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
-     """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
-     <https://arxiv.org/pdf/2105.08582.pdf>`_.
-
-     >>> import tensorflow as tf
-     >>> from doctr.models import vitstr_base
-     >>> model = vitstr_base(pretrained=False)
-     >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
-     >>> out = model(input_tensor)
-
-     Args:
-         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
-         **kwargs: keyword arguments of the ViTSTR architecture
-
-     Returns:
-         text recognition architecture
-     """
-     return _vitstr(
-         "vitstr_base",
-         pretrained,
-         vit_b,
-         embedding_units=768,
-         patch_size=(4, 8),
-         **kwargs,
-     )