python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -1,419 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
-
8
- from copy import deepcopy
9
- from typing import Any
10
-
11
- import numpy as np
12
- import tensorflow as tf
13
- from tensorflow.keras import Model, Sequential, layers
14
-
15
- from doctr.file_utils import CLASS_NAME
16
- from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params
17
- from doctr.utils.repr import NestedObject
18
-
19
- from ...classification import textnet_base, textnet_small, textnet_tiny
20
- from ...modules.layers import FASTConvLayer
21
- from .base import _FAST, FASTPostProcessor
22
-
23
__all__ = [
    "FAST",
    "fast_tiny",
    "fast_small",
    "fast_base",
    "reparameterize",
]
24
-
25
-
26
# All FAST variants share the same input size and normalization statistics; only the checkpoint URL differs.
default_cfgs: dict[str, dict[str, Any]] = {
    arch_name: {
        "input_shape": (1024, 1024, 3),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": checkpoint_url,
    }
    for arch_name, checkpoint_url in (
        ("fast_tiny", "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0"),
        ("fast_small", "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0"),
        ("fast_base", "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0"),
    )
}
46
-
47
-
48
class FastNeck(layers.Layer, NestedObject):
    """Neck of the FAST architecture.

    Each of the four incoming feature maps goes through its own 3x3 ``FASTConvLayer`` (created with
    ``in_channels`` scaled by 1, 2, 4 and 8 respectively). The last three results are bilinearly resized
    to the spatial size of the first one, and all four are concatenated along the last axis.

    Args:
        in_channels: channel count used for the first reduction layer (scaled by 2, 4, 8 for the others)
        out_channels: number of output channels of each reduction layer
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 128,
    ) -> None:
        super().__init__()
        self.reduction = [
            FASTConvLayer(in_channels * multiplier, out_channels, kernel_size=3) for multiplier in (1, 2, 4, 8)
        ]

    def _upsample(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
        # Bilinearly resize x to the spatial size of y (axes 1:3, i.e. H and W for channels-last tensors)
        height_width = y.shape[1:3]
        return tf.image.resize(x, size=height_width, method="bilinear")

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        # Exactly four feature maps are expected; unpacking raises otherwise
        first, second, third, fourth = x
        reduced = [
            conv(feature, **kwargs) for conv, feature in zip(self.reduction, (first, second, third, fourth))
        ]
        reference = reduced[0]
        aligned = [reference, *(self._upsample(feature, reference) for feature in reduced[1:])]
        return tf.concat(aligned, axis=-1)
73
-
74
-
75
class FastHead(Sequential):
    """Head of the FAST architecture.

    A sequential stack made of a 3x3 ``FASTConvLayer``, a dropout layer and a bias-free 1x1 ``Conv2D``
    producing ``num_classes`` output channels.

    Args:
        in_channels: number of input channels
        num_classes: number of output classes
        out_channels: number of channels produced by the intermediate ``FASTConvLayer``
        dropout: dropout probability
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        out_channels: int = 128,
        dropout: float = 0.1,
    ) -> None:
        feature_conv = FASTConvLayer(in_channels, out_channels, kernel_size=3)
        regularization = layers.Dropout(dropout)
        classifier = layers.Conv2D(num_classes, kernel_size=1, use_bias=False)
        super().__init__([feature_conv, regularization, classifier])
98
-
99
-
100
class FAST(_FAST, Model, NestedObject):
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_.

    Args:
        feature_extractor: the backbone serving as feature extractor
        bin_thresh: threshold for binarization
        box_thresh: minimal objectness score to consider a box
        dropout_prob: dropout probability
        pooling_size: size of the pooling layer
        assume_straight_pages: if True, fit straight bounding boxes only
        exportable: onnx exportable returns only logits
        cfg: the configuration dict of the model (a fresh empty dict when omitted)
        class_names: list of class names (``[CLASS_NAME]`` when omitted)
    """

    _children_names: list[str] = ["feat_extractor", "neck", "head", "postprocessor"]

    def __init__(
        self,
        feature_extractor: IntermediateLayerGetter,
        bin_thresh: float = 0.1,
        box_thresh: float = 0.1,
        dropout_prob: float = 0.1,
        pooling_size: int = 4,  # different from paper performs better on close text-rich images
        assume_straight_pages: bool = True,
        exportable: bool = False,
        cfg: dict[str, Any] | None = None,
        class_names: list[str] | None = None,
    ) -> None:
        super().__init__()
        # Build fresh containers per instance instead of sharing mutable default arguments
        self.class_names = class_names if class_names is not None else [CLASS_NAME]
        num_classes: int = len(self.class_names)
        self.cfg = cfg if cfg is not None else {}

        self.feat_extractor = feature_extractor
        self.exportable = exportable
        self.assume_straight_pages = assume_straight_pages

        # Identify the number of channels for the neck & head initialization
        feat_out_channels = [
            layers.Input(shape=in_shape[1:]).shape[-1] for in_shape in self.feat_extractor.output_shape
        ]
        # Initialize neck & head
        # NOTE(review): the head is built for feat_out_channels[-1] input channels while the neck concatenates
        # four branches of feat_out_channels[1] channels each; this presumably relies on
        # 4 * feat_out_channels[1] == feat_out_channels[-1] for the TextNet backbones — confirm.
        self.neck = FastNeck(feat_out_channels[0], feat_out_channels[1])
        self.head = FastHead(feat_out_channels[-1], num_classes, feat_out_channels[1], dropout_prob)

        # NOTE: The post processing from the paper works not well for text-rich images
        # so we use a modified version from DBNet
        self.postprocessor = FASTPostProcessor(
            assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
        )

        # Pooling layer as erosion reversal as described in the paper
        self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same")

    def compute_loss(
        self,
        out_map: tf.Tensor,
        target: list[dict[str, np.ndarray]],
        eps: float = 1e-6,
    ) -> tf.Tensor:
        """Compute the FAST loss.

        The loss is a Dice loss on the text region map (scaled by 0.5) plus a Dice loss on the text
        kernel map.

        Args:
            out_map: output logits of the model, channels-last, of shape (N, H, W, num_classes)
            target: list of dictionary where each dict has a `boxes` and a `flags` entry
            eps: epsilon factor in dice loss

        Returns:
            A loss tensor
        """
        targets = self.build_target(target, out_map.shape[1:], True)

        seg_target = tf.convert_to_tensor(targets[0], dtype=out_map.dtype)
        seg_mask = tf.convert_to_tensor(targets[1], dtype=out_map.dtype)
        shrunken_kernel = tf.convert_to_tensor(targets[2], dtype=out_map.dtype)

        def ohem(score: tf.Tensor, gt: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
            # Online hard example mining: keep every valid positive pixel plus at most 3x as many negatives,
            # choosing the negatives with the HIGHEST scores (the hardest ones)
            pos_num = tf.reduce_sum(tf.cast(gt > 0.5, dtype=tf.int32)) - tf.reduce_sum(
                tf.cast((gt > 0.5) & (mask <= 0.5), dtype=tf.int32)
            )
            neg_num = tf.reduce_sum(tf.cast(gt <= 0.5, dtype=tf.int32))
            neg_num = tf.minimum(pos_num * 3, neg_num)

            if neg_num == 0 or pos_num == 0:
                return mask

            # top_k sorts in descending order, so the last of the k largest negative scores is the threshold.
            # (Taking top_k of the *negated* scores would yield the k-th lowest score and keep almost every
            # negative pixel instead of the hardest ones.)
            neg_scores_topk, _ = tf.nn.top_k(tf.boolean_mask(score, gt <= 0.5), k=neg_num)
            threshold = neg_scores_topk[-1]

            selected_mask = tf.math.logical_and((score >= threshold) | (gt > 0.5), (mask > 0.5))
            # Use the same dtype as the early-return branch so tf.stack and the products below never mix
            # float32 with a reduced-precision out_map dtype
            return tf.cast(selected_mask, dtype=mask.dtype)

        if len(self.class_names) > 1:
            kernels = tf.nn.softmax(out_map, axis=-1)
            prob_map = tf.nn.softmax(self.pooling(out_map), axis=-1)
        else:
            kernels = tf.sigmoid(out_map)
            prob_map = tf.sigmoid(self.pooling(out_map))

        # Text region loss: Dice loss between the pooled probability map and the text segmentation target,
        # restricted to OHEM-selected pixels and scaled by 0.5
        selected_masks = tf.stack(
            [ohem(score, gt, mask) for score, gt, mask in zip(prob_map, seg_target, seg_mask)], axis=0
        )
        inter = tf.reduce_sum(selected_masks * prob_map * seg_target, axis=(0, 1, 2))
        cardinality = tf.reduce_sum(selected_masks * (prob_map + seg_target), axis=(0, 1, 2))
        text_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps))) * 0.5

        # Text kernel loss: Dice loss between the raw kernel map and the shrunken kernels,
        # restricted to valid text pixels
        selected_masks = seg_target * seg_mask
        inter = tf.reduce_sum(selected_masks * kernels * shrunken_kernel, axis=(0, 1, 2))
        cardinality = tf.reduce_sum(selected_masks * (kernels + shrunken_kernel), axis=(0, 1, 2))
        kernel_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps)))

        return text_loss + kernel_loss

    def call(
        self,
        x: tf.Tensor,
        target: list[dict[str, np.ndarray]] | None = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Run the detector.

        Args:
            x: batch of channels-last input images
            target: optional list of target dicts; when given, the loss is computed
            return_model_output: if True, include the probability map under ``"out_map"``
            return_preds: if True, include post-processed predictions under ``"preds"`` even when a target is given
            **kwargs: forwarded to the backbone, neck, head and pooling layers

        Returns:
            a dict holding only ``"logits"`` in exportable mode, otherwise any of ``"out_map"``, ``"preds"``
            and ``"loss"`` depending on the arguments
        """
        feat_maps = self.feat_extractor(x, **kwargs)
        # Pass through the Neck & Head & Upsample
        feat_concat = self.neck(feat_maps, **kwargs)
        logits: tf.Tensor = self.head(feat_concat, **kwargs)
        logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)

        out: dict[str, tf.Tensor] = {}
        if self.exportable:
            out["logits"] = logits
            return out

        if return_model_output or target is None or return_preds:
            # NOTE(review): a sigmoid is used here for any number of classes, whereas compute_loss uses a
            # softmax when there are several classes — confirm this is intended.
            prob_map = _bf16_to_float32(tf.math.sigmoid(self.pooling(logits, **kwargs)))

        if return_model_output:
            out["out_map"] = prob_map

        if target is None or return_preds:
            # Post-process boxes (keep only text predictions)
            out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]

        if target is not None:
            loss = self.compute_loss(logits, target)
            out["loss"] = loss

        return out
251
-
252
-
253
def reparameterize(model: FAST | layers.Layer) -> FAST:
    """Fuse batchnorm and conv layers and reparameterize the model

    Walks ``model.layers`` and handles each layer as follows:

    - ``FASTConvLayer`` (including the reduction layers of a ``FastNeck``): ``reparameterize_layer()`` is called.
    - ``BatchNormalization`` seen after a ``Conv2D`` of the same ``layers`` list: it is folded into that conv by
      rescaling the conv kernel (and its bias, when the conv has one).
    - ``FastHead`` and any other layer exposing a ``layers`` attribute: handled by a recursive call.

    Args:
        model: the FAST model (or a sub-layer, on recursive calls) to reparameterize

    Returns:
        the reparameterized model (the same object that was passed in)
    """
    last_conv = None
    last_conv_idx = None

    for idx, layer in enumerate(model.layers):
        # Layers that match neither check (e.g. activations, dropout) are skipped entirely.
        # NOTE(review): skipped layers do not reset last_conv, so a BN could be folded into a conv that
        # is separated from it by such a layer — confirm this cannot happen in the backbones used.
        if hasattr(layer, "layers") or isinstance(
            layer, (FASTConvLayer, FastNeck, FastHead, layers.BatchNormalization, layers.Conv2D)
        ):
            if isinstance(layer, layers.BatchNormalization):
                # fuse batchnorm only if a conv layer was seen before it (last_conv is set)
                if last_conv is None:
                    continue
                conv_w = last_conv.kernel
                conv_b = last_conv.bias if last_conv.use_bias else tf.zeros_like(layer.moving_mean)

                # Per-output-channel BN scale: gamma / sqrt(var + eps), broadcast over the last kernel axis
                factor = layer.gamma / tf.sqrt(layer.moving_variance + layer.epsilon)
                # NOTE(review): this rebinds `kernel` to a new tensor instead of assigning into the existing
                # variable (compare the bias, which uses .assign) — confirm the result is tracked as intended.
                last_conv.kernel = conv_w * factor.numpy().reshape([1, 1, 1, -1])
                # NOTE(review): when the conv has no bias, the BN shift (beta - mean * factor) is not
                # transferred anywhere — verify this is acceptable.
                if last_conv.use_bias:
                    last_conv.bias.assign((conv_b - layer.moving_mean) * factor + layer.beta)
                # NOTE(review): if `model.layers` returns a fresh list (as recent tf.keras versions do), these
                # item assignments do not modify the model and the BN layer stays in place — confirm.
                model.layers[last_conv_idx] = last_conv  # Replace the last conv layer with the fused version
                model.layers[idx] = layers.Lambda(lambda x: x)
                last_conv = None
            elif isinstance(layer, layers.Conv2D):
                # Remember the conv so that a following BatchNormalization can be folded into it
                last_conv = layer
                last_conv_idx = idx
            elif isinstance(layer, FASTConvLayer):
                layer.reparameterize_layer()
            elif isinstance(layer, FastNeck):
                for reduction in layer.reduction:
                    reduction.reparameterize_layer()
            elif isinstance(layer, FastHead):
                reparameterize(layer)
            else:
                # Any other container layer (has a `layers` attribute): recurse into it
                reparameterize(layer)
    return model
297
-
298
-
299
def _fast(
    arch: str,
    pretrained: bool,
    backbone_fn,
    feat_layers: list[str],
    pretrained_backbone: bool = True,
    input_shape: tuple[int, int, int] | None = None,
    **kwargs: Any,
) -> FAST:
    """Build a FAST model from the ``default_cfgs`` entry named ``arch``.

    Args:
        arch: key of the architecture in ``default_cfgs``
        pretrained: whether to load the full FAST checkpoint from the configured URL
        backbone_fn: constructor of the backbone (called with ``input_shape``, ``include_top`` and ``pretrained``)
        feat_layers: names of the backbone layers whose outputs feed the neck
        pretrained_backbone: whether to load backbone weights (ignored when ``pretrained`` is True)
        input_shape: optional override of the configured input shape
        **kwargs: forwarded to ``FAST``; ``class_names`` is sorted, or taken from the config when missing or empty

    Returns:
        the built FAST model
    """
    # Backbone weights are only useful when the full checkpoint is not loaded on top of them
    load_backbone_weights = pretrained_backbone and not pretrained

    # Work on a private copy of the config so the module-level defaults stay untouched
    config = deepcopy(default_cfgs[arch])
    if input_shape:
        config["input_shape"] = input_shape

    requested_classes = kwargs.get("class_names", None)
    kwargs["class_names"] = (
        sorted(requested_classes) if requested_classes else config.get("class_names", [CLASS_NAME])
    )

    backbone = backbone_fn(
        input_shape=config["input_shape"],
        include_top=False,
        pretrained=load_backbone_weights,
    )
    model = FAST(IntermediateLayerGetter(backbone, feat_layers), cfg=config, **kwargs)
    _build_model(model)

    if not pretrained:
        return model

    # A class set that differs from the checkpoint's means mismatching head shapes: skip those layers for fine-tuning
    reference_classes = default_cfgs[arch].get("class_names", [CLASS_NAME])
    load_pretrained_params(model, config["url"], skip_mismatch=kwargs["class_names"] != reference_classes)
    return model
342
-
343
-
344
def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import fast_tiny
    >>> model = fast_tiny(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture (e.g. ``pretrained_backbone``, ``input_shape``,
            ``class_names``, ``bin_thresh``)

    Returns:
        text detection architecture
    """
    return _fast(
        "fast_tiny",
        pretrained,
        textnet_tiny,
        ["stage_0", "stage_1", "stage_2", "stage_3"],
        **kwargs,
    )
368
-
369
-
370
def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import fast_small
    >>> model = fast_small(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture (e.g. ``pretrained_backbone``, ``input_shape``,
            ``class_names``, ``bin_thresh``)

    Returns:
        text detection architecture
    """
    return _fast(
        "fast_small",
        pretrained,
        textnet_small,
        ["stage_0", "stage_1", "stage_2", "stage_3"],
        **kwargs,
    )
394
-
395
-
396
def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import fast_base
    >>> model = fast_base(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture (e.g. ``pretrained_backbone``, ``input_shape``,
            ``class_names``, ``bin_thresh``)

    Returns:
        text detection architecture
    """
    return _fast(
        "fast_base",
        pretrained,
        textnet_base,
        ["stage_0", "stage_1", "stage_2", "stage_3"],
        **kwargs,
    )