python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -1,122 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- import math
7
- from typing import Any
8
-
9
- import numpy as np
10
- import tensorflow as tf
11
-
12
- from doctr.transforms import Normalize, Resize
13
- from doctr.utils.multithreading import multithread_exec
14
- from doctr.utils.repr import NestedObject
15
-
16
- __all__ = ["PreProcessor"]
17
-
18
-
19
- class PreProcessor(NestedObject):
20
- """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
21
-
22
- Args:
23
- output_size: expected size of each page in format (H, W)
24
- batch_size: the size of page batches
25
- mean: mean value of the training distribution by channel
26
- std: standard deviation of the training distribution by channel
27
- **kwargs: additional arguments for the resizing operation
28
- """
29
-
30
- _children_names: list[str] = ["resize", "normalize"]
31
-
32
- def __init__(
33
- self,
34
- output_size: tuple[int, int],
35
- batch_size: int,
36
- mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
37
- std: tuple[float, float, float] = (1.0, 1.0, 1.0),
38
- **kwargs: Any,
39
- ) -> None:
40
- self.batch_size = batch_size
41
- self.resize = Resize(output_size, **kwargs)
42
- # Perform the division by 255 at the same time
43
- self.normalize = Normalize(mean, std)
44
- self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []
45
-
46
- def batch_inputs(self, samples: list[tf.Tensor]) -> list[tf.Tensor]:
47
- """Gather samples into batches for inference purposes
48
-
49
- Args:
50
- samples: list of samples (tf.Tensor)
51
-
52
- Returns:
53
- list of batched samples
54
- """
55
- num_batches = int(math.ceil(len(samples) / self.batch_size))
56
- batches = [
57
- tf.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0)
58
- for idx in range(int(num_batches))
59
- ]
60
-
61
- return batches
62
-
63
- def sample_transforms(self, x: np.ndarray | tf.Tensor) -> tf.Tensor:
64
- if x.ndim != 3:
65
- raise AssertionError("expected list of 3D Tensors")
66
- if isinstance(x, np.ndarray):
67
- if x.dtype not in (np.uint8, np.float32):
68
- raise TypeError("unsupported data type for numpy.ndarray")
69
- x = tf.convert_to_tensor(x)
70
- elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
71
- raise TypeError("unsupported data type for torch.Tensor")
72
- # Data type & 255 division
73
- if x.dtype == tf.uint8:
74
- x = tf.image.convert_image_dtype(x, dtype=tf.float32)
75
- # Resizing
76
- x = self.resize(x)
77
-
78
- return x
79
-
80
- def __call__(self, x: tf.Tensor | np.ndarray | list[tf.Tensor | np.ndarray]) -> list[tf.Tensor]:
81
- """Prepare document data for model forwarding
82
-
83
- Args:
84
- x: list of images (np.array) or tensors (already resized and batched)
85
-
86
- Returns:
87
- list of page batches
88
- """
89
- # Input type check
90
- if isinstance(x, (np.ndarray, tf.Tensor)):
91
- if x.ndim != 4:
92
- raise AssertionError("expected 4D Tensor")
93
- if isinstance(x, np.ndarray):
94
- if x.dtype not in (np.uint8, np.float32):
95
- raise TypeError("unsupported data type for numpy.ndarray")
96
- x = tf.convert_to_tensor(x)
97
- elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
98
- raise TypeError("unsupported data type for torch.Tensor")
99
-
100
- # Data type & 255 division
101
- if x.dtype == tf.uint8:
102
- x = tf.image.convert_image_dtype(x, dtype=tf.float32)
103
- # Resizing
104
- if (x.shape[1], x.shape[2]) != self.resize.output_size:
105
- x = tf.image.resize(
106
- x, self.resize.output_size, method=self.resize.method, antialias=self.resize.antialias
107
- )
108
-
109
- batches = [x]
110
-
111
- elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
112
- # Sample transform (to tensor, resize)
113
- samples = list(multithread_exec(self.sample_transforms, x, threads=1 if self._runs_on_cuda else None))
114
- # Batching
115
- batches = self.batch_inputs(samples)
116
- else:
117
- raise TypeError(f"invalid input type: {type(x)}")
118
-
119
- # Batch transforms (normalize)
120
- batches = list(multithread_exec(self.normalize, batches, threads=1 if self._runs_on_cuda else None))
121
-
122
- return batches
@@ -1,308 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- from copy import deepcopy
7
- from typing import Any
8
-
9
- import tensorflow as tf
10
- from tensorflow.keras import layers
11
- from tensorflow.keras.models import Model, Sequential
12
-
13
- from doctr.datasets import VOCABS
14
-
15
- from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
16
- from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
17
- from ..core import RecognitionModel, RecognitionPostProcessor
18
-
19
- __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
20
-
21
- default_cfgs: dict[str, dict[str, Any]] = {
22
- "crnn_vgg16_bn": {
23
- "mean": (0.694, 0.695, 0.693),
24
- "std": (0.299, 0.296, 0.301),
25
- "input_shape": (32, 128, 3),
26
- "vocab": VOCABS["legacy_french"],
27
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0",
28
- },
29
- "crnn_mobilenet_v3_small": {
30
- "mean": (0.694, 0.695, 0.693),
31
- "std": (0.299, 0.296, 0.301),
32
- "input_shape": (32, 128, 3),
33
- "vocab": VOCABS["french"],
34
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0",
35
- },
36
- "crnn_mobilenet_v3_large": {
37
- "mean": (0.694, 0.695, 0.693),
38
- "std": (0.299, 0.296, 0.301),
39
- "input_shape": (32, 128, 3),
40
- "vocab": VOCABS["french"],
41
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0",
42
- },
43
- }
44
-
45
-
46
- class CTCPostProcessor(RecognitionPostProcessor):
47
- """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
48
-
49
- Args:
50
- vocab: string containing the ordered sequence of supported characters
51
- ignore_case: if True, ignore case of letters
52
- ignore_accents: if True, ignore accents of letters
53
- """
54
-
55
- def __call__(
56
- self,
57
- logits: tf.Tensor,
58
- beam_width: int = 1,
59
- top_paths: int = 1,
60
- ) -> list[tuple[str, float]] | list[tuple[list[str] | list[float]]]:
61
- """Performs decoding of raw output with CTC and decoding of CTC predictions
62
- with label_to_idx mapping dictionnary
63
-
64
- Args:
65
- logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1
66
- beam_width: An int scalar >= 0 (beam search beam width).
67
- top_paths: An int scalar >= 0, <= beam_width (controls output size).
68
-
69
- Returns:
70
- A list of decoded words of length BATCH_SIZE
71
-
72
-
73
- """
74
- # Decode CTC
75
- _decoded, _log_prob = tf.nn.ctc_beam_search_decoder(
76
- tf.transpose(logits, perm=[1, 0, 2]),
77
- tf.fill(tf.shape(logits)[:1], tf.shape(logits)[1]),
78
- beam_width=beam_width,
79
- top_paths=top_paths,
80
- )
81
-
82
- _decoded = tf.sparse.concat(
83
- 1,
84
- [tf.sparse.expand_dims(dec, axis=1) for dec in _decoded],
85
- expand_nonconcat_dims=True,
86
- ) # dim : batchsize x beamwidth x actual_max_len_predictions
87
- out_idxs = tf.sparse.to_dense(_decoded, default_value=len(self.vocab))
88
-
89
- # Map it to characters
90
- _decoded_strings_pred = tf.strings.reduce_join(
91
- inputs=tf.nn.embedding_lookup(tf.constant(self._embedding, dtype=tf.string), out_idxs),
92
- axis=-1,
93
- )
94
- _decoded_strings_pred = tf.strings.split(_decoded_strings_pred, "<eos>")
95
- decoded_strings_pred = tf.sparse.to_dense(_decoded_strings_pred.to_sparse(), default_value="not valid")[
96
- :, :, 0
97
- ] # dim : batch_size x beam_width
98
-
99
- if top_paths == 1:
100
- probs = tf.math.exp(tf.squeeze(_log_prob, axis=1)) # dim : batchsize
101
- decoded_strings_pred = tf.squeeze(decoded_strings_pred, axis=1)
102
- word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
103
- else:
104
- probs = tf.math.exp(_log_prob) # dim : batchsize x beamwidth
105
- word_values = [[word.decode() for word in words] for words in decoded_strings_pred.numpy().tolist()]
106
- return list(zip(word_values, probs.numpy().tolist()))
107
-
108
-
109
- class CRNN(RecognitionModel, Model):
110
- """Implements a CRNN architecture as described in `"An End-to-End Trainable Neural Network for Image-based
111
- Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
112
-
113
- Args:
114
- feature_extractor: the backbone serving as feature extractor
115
- vocab: vocabulary used for encoding
116
- rnn_units: number of units in the LSTM layers
117
- exportable: onnx exportable returns only logits
118
- beam_width: beam width for beam search decoding
119
- top_paths: number of top paths for beam search decoding
120
- cfg: configuration dictionary
121
- """
122
-
123
- _children_names: list[str] = ["feat_extractor", "decoder", "postprocessor"]
124
-
125
- def __init__(
126
- self,
127
- feature_extractor: Model,
128
- vocab: str,
129
- rnn_units: int = 128,
130
- exportable: bool = False,
131
- beam_width: int = 1,
132
- top_paths: int = 1,
133
- cfg: dict[str, Any] | None = None,
134
- ) -> None:
135
- # Initialize kernels
136
- h, w, c = feature_extractor.output_shape[1:]
137
-
138
- super().__init__()
139
- self.vocab = vocab
140
- self.max_length = w
141
- self.cfg = cfg
142
- self.exportable = exportable
143
- self.feat_extractor = feature_extractor
144
-
145
- self.decoder = Sequential([
146
- layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
147
- layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
148
- layers.Dense(units=len(vocab) + 1),
149
- ])
150
- self.decoder.build(input_shape=(None, w, h * c))
151
-
152
- self.postprocessor = CTCPostProcessor(vocab=vocab)
153
-
154
- self.beam_width = beam_width
155
- self.top_paths = top_paths
156
-
157
- def compute_loss(
158
- self,
159
- model_output: tf.Tensor,
160
- target: list[str],
161
- ) -> tf.Tensor:
162
- """Compute CTC loss for the model.
163
-
164
- Args:
165
- model_output: predicted logits of the model
166
- target: lengths of each gt word inside the batch
167
-
168
- Returns:
169
- The loss of the model on the batch
170
- """
171
- gt, seq_len = self.build_target(target)
172
- batch_len = model_output.shape[0]
173
- input_length = tf.fill((batch_len,), model_output.shape[1])
174
- ctc_loss = tf.nn.ctc_loss(
175
- gt, model_output, seq_len, input_length, logits_time_major=False, blank_index=len(self.vocab)
176
- )
177
- return ctc_loss
178
-
179
- def call(
180
- self,
181
- x: tf.Tensor,
182
- target: list[str] | None = None,
183
- return_model_output: bool = False,
184
- return_preds: bool = False,
185
- beam_width: int = 1,
186
- top_paths: int = 1,
187
- **kwargs: Any,
188
- ) -> dict[str, Any]:
189
- if kwargs.get("training", False) and target is None:
190
- raise ValueError("Need to provide labels during training")
191
-
192
- features = self.feat_extractor(x, **kwargs)
193
- # B x H x W x C --> B x W x H x C
194
- transposed_feat = tf.transpose(features, perm=[0, 2, 1, 3])
195
- w, h, c = transposed_feat.get_shape().as_list()[1:]
196
- # B x W x H x C --> B x W x H * C
197
- features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
198
- logits = _bf16_to_float32(self.decoder(features_seq, **kwargs))
199
-
200
- out: dict[str, tf.Tensor] = {}
201
- if self.exportable:
202
- out["logits"] = logits
203
- return out
204
-
205
- if return_model_output:
206
- out["out_map"] = logits
207
-
208
- if target is None or return_preds:
209
- # Post-process boxes
210
- out["preds"] = self.postprocessor(logits, beam_width=beam_width, top_paths=top_paths)
211
-
212
- if target is not None:
213
- out["loss"] = self.compute_loss(logits, target)
214
-
215
- return out
216
-
217
-
218
- def _crnn(
219
- arch: str,
220
- pretrained: bool,
221
- backbone_fn,
222
- pretrained_backbone: bool = True,
223
- input_shape: tuple[int, int, int] | None = None,
224
- **kwargs: Any,
225
- ) -> CRNN:
226
- pretrained_backbone = pretrained_backbone and not pretrained
227
-
228
- kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
229
-
230
- _cfg = deepcopy(default_cfgs[arch])
231
- _cfg["vocab"] = kwargs["vocab"]
232
- _cfg["input_shape"] = input_shape or default_cfgs[arch]["input_shape"]
233
-
234
- feat_extractor = backbone_fn(
235
- input_shape=_cfg["input_shape"],
236
- include_top=False,
237
- pretrained=pretrained_backbone,
238
- )
239
-
240
- # Build the model
241
- model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
242
- _build_model(model)
243
- # Load pretrained parameters
244
- if pretrained:
245
- # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
246
- load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
247
-
248
- return model
249
-
250
-
251
- def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
252
- """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
253
- Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
254
-
255
- >>> import tensorflow as tf
256
- >>> from doctr.models import crnn_vgg16_bn
257
- >>> model = crnn_vgg16_bn(pretrained=True)
258
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
259
- >>> out = model(input_tensor)
260
-
261
- Args:
262
- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
263
- **kwargs: keyword arguments of the CRNN architecture
264
-
265
- Returns:
266
- text recognition architecture
267
- """
268
- return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs)
269
-
270
-
271
- def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
272
- """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
273
- Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
274
-
275
- >>> import tensorflow as tf
276
- >>> from doctr.models import crnn_mobilenet_v3_small
277
- >>> model = crnn_mobilenet_v3_small(pretrained=True)
278
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
279
- >>> out = model(input_tensor)
280
-
281
- Args:
282
- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
283
- **kwargs: keyword arguments of the CRNN architecture
284
-
285
- Returns:
286
- text recognition architecture
287
- """
288
- return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs)
289
-
290
-
291
- def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
292
- """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
293
- Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
294
-
295
- >>> import tensorflow as tf
296
- >>> from doctr.models import crnn_mobilenet_v3_large
297
- >>> model = crnn_mobilenet_v3_large(pretrained=True)
298
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
299
- >>> out = model(input_tensor)
300
-
301
- Args:
302
- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
303
- **kwargs: keyword arguments of the CRNN architecture
304
-
305
- Returns:
306
- text recognition architecture
307
- """
308
- return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs)