python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
@@ -0,0 +1,442 @@
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from typing import Any, Callable, Dict, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ from torchvision.models._utils import IntermediateLayerGetter
13
+
14
+ from doctr.file_utils import CLASS_NAME
15
+
16
+ from ...classification import textnet_base, textnet_small, textnet_tiny
17
+ from ...modules.layers import FASTConvLayer
18
+ from ...utils import _bf16_to_float32, load_pretrained_params
19
+ from .base import _FAST, FASTPostProcessor
20
+
21
+ __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
22
+
23
+
24
# Default model configurations: every FAST variant shares the same input size
# and normalization statistics, and ships without a pretrained URL.
default_cfgs: Dict[str, Dict[str, Any]] = {
    arch: {
        "input_shape": (3, 1024, 1024),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": None,
    }
    for arch in ("fast_tiny", "fast_small", "fast_base")
}
44
+
45
+
46
class FastNeck(nn.Module):
    """Neck of the FAST architecture: reduces each backbone stage to a common
    channel count, upsamples every deeper stage to the first stage's resolution
    and concatenates all of them along the channel axis.

    Args:
    ----
        in_channels: number of input channels
        out_channels: number of output channels
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 128,
    ) -> None:
        super().__init__()
        # One 3x3 conv per backbone stage; deeper stages carry proportionally more channels
        self.reduction = nn.ModuleList([
            FASTConvLayer(in_channels * scale, out_channels, kernel_size=3) for scale in (1, 2, 4, 8)
        ])

    def _upsample(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Resize x to the spatial resolution of y
        return F.interpolate(x, size=y.shape[-2:], mode="bilinear")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reduce each stage to the common channel count
        reduced = [layer(feat) for layer, feat in zip(self.reduction, x)]
        base = reduced[0]
        # Align the deeper (coarser) stages to the first stage's resolution
        aligned = [base] + [self._upsample(feat, base) for feat in reduced[1:]]
        return torch.cat(aligned, 1)
74
+
75
+
76
class FastHead(nn.Sequential):
    """Head of the FAST architecture: a 3x3 conv block, dropout, then a
    bias-free 1x1 convolution projecting onto the class channels.

    Args:
    ----
        in_channels: number of input channels
        num_classes: number of output classes
        out_channels: number of output channels
        dropout: dropout probability
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        out_channels: int = 128,
        dropout: float = 0.1,
    ) -> None:
        super().__init__(
            FASTConvLayer(in_channels, out_channels, kernel_size=3),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, num_classes, kernel_size=1, bias=False),
        )
100
+
101
+
102
class FAST(_FAST, nn.Module):
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_.

    Args:
    ----
        feat_extractor: the backbone serving as feature extractor
        bin_thresh: threshold for binarization
        box_thresh: minimal objectness score to consider a box
        dropout_prob: dropout probability
        pooling_size: size of the pooling layer
        assume_straight_pages: if True, fit straight bounding boxes only
        exportable: onnx exportable returns only logits
        cfg: the configuration dict of the model (defaults to an empty dict)
        class_names: list of class names (defaults to ``[CLASS_NAME]``)
    """

    def __init__(
        self,
        feat_extractor: IntermediateLayerGetter,
        bin_thresh: float = 0.3,
        box_thresh: float = 0.1,
        dropout_prob: float = 0.1,
        pooling_size: int = 4,  # different from paper performs better on close text-rich images
        assume_straight_pages: bool = True,
        exportable: bool = False,
        # None sentinels instead of mutable defaults ({} / [CLASS_NAME]):
        # a shared dict/list default would be aliased across every instance.
        cfg: Optional[Dict[str, Any]] = None,
        class_names: Optional[List[str]] = None,
    ) -> None:
        super().__init__()
        self.class_names = class_names if class_names is not None else [CLASS_NAME]
        num_classes: int = len(self.class_names)
        self.cfg = cfg if cfg is not None else {}

        self.exportable = exportable
        self.assume_straight_pages = assume_straight_pages

        self.feat_extractor = feat_extractor
        # Identify the number of channels for the neck & head initialization
        # with a dummy forward pass (run in eval mode so BN running stats stay untouched)
        _is_training = self.feat_extractor.training
        self.feat_extractor = self.feat_extractor.eval()
        with torch.no_grad():
            out = self.feat_extractor(torch.zeros((1, 3, 32, 32)))
            feat_out_channels = [v.shape[1] for _, v in out.items()]

        if _is_training:
            self.feat_extractor = self.feat_extractor.train()

        # Initialize neck & head
        self.neck = FastNeck(feat_out_channels[0], feat_out_channels[1])
        self.prob_head = FastHead(feat_out_channels[-1], num_classes, feat_out_channels[1], dropout_prob)

        # NOTE: The post processing from the paper works not well for text-rich images
        # so we use a modified version from DBNet
        self.postprocessor = FASTPostProcessor(
            assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
        )

        # Pooling layer as erosion reversal as described in the paper
        self.pooling = nn.MaxPool2d(kernel_size=pooling_size // 2 + 1, stride=1, padding=(pooling_size // 2) // 2)

        for n, m in self.named_modules():
            # Don't override the initialization of the backbone
            if n.startswith("feat_extractor."):
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.0)
                m.bias.data.zero_()

    def forward(
        self,
        x: torch.Tensor,
        target: Optional[List[np.ndarray]] = None,
        return_model_output: bool = False,
        return_preds: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Run the detector and assemble the requested outputs.

        Args:
        ----
            x: input image batch of shape (N, 3, H, W)
            target: optional detection targets; when provided, the loss is computed
            return_model_output: whether to include the probability map under "out_map"
            return_preds: whether to include post-processed boxes under "preds"

        Returns:
        -------
            a dict holding any of "logits" (exportable mode), "out_map", "preds", "loss"
        """
        # Extract feature maps at different stages
        feats = self.feat_extractor(x)
        feats = [feats[str(idx)] for idx in range(len(feats))]
        # Pass through the Neck & Head & Upsample
        feat_concat = self.neck(feats)
        logits = F.interpolate(self.prob_head(feat_concat), size=x.shape[-2:], mode="bilinear")

        out: Dict[str, Any] = {}
        if self.exportable:
            # ONNX export path: raw logits only
            out["logits"] = logits
            return out

        if return_model_output or target is None or return_preds:
            # Pooling reverses the kernel erosion before computing probabilities
            prob_map = _bf16_to_float32(torch.sigmoid(self.pooling(logits)))

        if return_model_output:
            out["out_map"] = prob_map

        if target is None or return_preds:
            # Post-process boxes (keep only text predictions)
            out["preds"] = [
                dict(zip(self.class_names, preds))
                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
            ]

        if target is not None:
            loss = self.compute_loss(logits, target)
            out["loss"] = loss

        return out

    def compute_loss(
        self,
        out_map: torch.Tensor,
        target: List[np.ndarray],
        eps: float = 1e-6,
    ) -> torch.Tensor:
        """Compute the FAST loss: Dice loss on the text map (scaled by 0.5) plus
        Dice loss on the text kernel map.

        Args:
        ----
            out_map: output feature map of the model of shape (N, num_classes, H, W)
            target: list of target arrays, as consumed by ``self.build_target``
                (NOTE(review): the previous docstring mentioned dicts with `boxes`/`flags`
                entries — the annotation says ``List[np.ndarray]``; verify against callers)
            eps: epsilon factor in dice loss

        Returns:
        -------
            A loss tensor
        """
        targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]

        seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
        shrunken_kernel = torch.from_numpy(targets[2]).to(out_map.device)
        seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)

        def ohem_sample(score: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
            # Online hard example mining: keep all positives and the hardest
            # negatives, capped at 3x the number of positives per class.
            masks = []
            for class_idx in range(gt.shape[0]):
                pos_num = int(torch.sum(gt[class_idx] > 0.5)) - int(
                    torch.sum((gt[class_idx] > 0.5) & (mask[class_idx] <= 0.5))
                )
                neg_num = int(torch.sum(gt[class_idx] <= 0.5))
                neg_num = int(min(pos_num * 3, neg_num))

                if neg_num == 0 or pos_num == 0:
                    # Nothing to mine for this class: keep the original mask
                    masks.append(mask[class_idx])
                    continue

                neg_score_sorted, _ = torch.sort(-score[class_idx][gt[class_idx] <= 0.5])
                threshold = -neg_score_sorted[neg_num - 1]

                selected_mask = ((score[class_idx] >= threshold) | (gt[class_idx] > 0.5)) & (mask[class_idx] > 0.5)
                masks.append(selected_mask)
            # combine all masks to shape (len(masks), H, W)
            return torch.stack(masks).unsqueeze(0).float()

        if len(self.class_names) > 1:
            kernels = torch.softmax(out_map, dim=1)
            prob_map = torch.softmax(self.pooling(out_map), dim=1)
        else:
            kernels = torch.sigmoid(out_map)
            prob_map = torch.sigmoid(self.pooling(out_map))

        # As described in the paper, the Dice loss on the text map is scaled by 0.5
        selected_masks = torch.cat(
            [ohem_sample(score, gt, mask) for score, gt, mask in zip(prob_map, seg_target, seg_mask)], 0
        ).float()
        inter = (selected_masks * prob_map * seg_target).sum((0, 2, 3))
        cardinality = (selected_masks * (prob_map + seg_target)).sum((0, 2, 3))
        text_loss = (1 - 2 * inter / (cardinality + eps)).mean() * 0.5

        # As described in the paper, we use the Dice loss for the text kernel map.
        selected_masks = seg_target * seg_mask
        inter = (selected_masks * kernels * shrunken_kernel).sum((0, 2, 3))  # noqa
        cardinality = (selected_masks * (kernels + shrunken_kernel)).sum((0, 2, 3))  # noqa
        kernel_loss = (1 - 2 * inter / (cardinality + eps)).mean()

        return text_loss + kernel_loss
280
+
281
+
282
def reparameterize(model: Union[FAST, nn.Module]) -> FAST:
    """Fuse batchnorm and conv layers and reparameterize the model for inference.

    Args:
    ----
        model: the FAST model to reparameterize

    Returns:
    -------
        the reparameterized model (modified in place)
    """
    last_conv = None
    last_conv_name = None

    # First, let layers that know how to collapse their own branches
    # (anything exposing `reparameterize_layer`, e.g. FASTConvLayer) do so.
    for module in model.modules():
        if hasattr(module, "reparameterize_layer"):
            module.reparameterize_layer()

    for name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # fold the batchnorm into the conv layer that directly precedes it
            if last_conv is None:
                continue
            conv_w = last_conv.weight
            conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean)

            # Standard conv+BN folding: scale weights by gamma / sqrt(running_var + eps)
            factor = child.weight / torch.sqrt(child.running_var + child.eps)
            last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
            last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
            model._modules[last_conv_name] = last_conv
            # the batchnorm is now redundant: replace it with a no-op
            model._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            # remember this conv so a directly-following BN can be folded into it
            last_conv = child
            last_conv_name = name
        else:
            # recurse into container modules (sequential blocks, neck, head, ...)
            # NOTE(review): `last_conv` is not reset here, so a BN separated from
            # its conv by another sibling would still be folded into that earlier
            # conv — presumably this never happens in this architecture; verify.
            reparameterize(child)

    return model  # type: ignore[return-value]
321
+
322
+
323
def _fast(
    arch: str,
    pretrained: bool,
    backbone_fn: Callable[[bool], nn.Module],
    feat_layers: List[str],
    pretrained_backbone: bool = True,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> FAST:
    """Build a FAST model around the given backbone and optionally load pretrained weights."""
    # A fully pretrained detector already carries its backbone weights
    pretrained_backbone = pretrained_backbone and not pretrained

    # Expose the requested backbone layers under consecutive string indices
    return_layers = {layer_name: str(idx) for idx, layer_name in enumerate(feat_layers)}
    feat_extractor = IntermediateLayerGetter(backbone_fn(pretrained_backbone), return_layers)

    default_class_names = default_cfgs[arch].get("class_names", [CLASS_NAME])
    if kwargs.get("class_names"):
        kwargs["class_names"] = sorted(kwargs["class_names"])
    else:
        kwargs["class_names"] = default_class_names

    # Build the model
    model = FAST(feat_extractor, cfg=default_cfgs[arch], **kwargs)

    # Load pretrained parameters
    if pretrained:
        # If the requested class set differs from the pretrained one, the
        # classification layer weights cannot be reused and must be dropped
        _ignore_keys = ignore_keys if kwargs["class_names"] != default_class_names else None
        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    return model
356
+
357
+
358
def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.

    >>> import torch
    >>> from doctr.models import fast_tiny
    >>> model = fast_tiny(pretrained=True)
    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture

    Returns:
    -------
        text detection architecture
    """
    feat_layers = ["3", "4", "5", "6"]
    # the classification layer is dropped when fine-tuning on a different class set
    return _fast("fast_tiny", pretrained, textnet_tiny, feat_layers, ignore_keys=["prob_head.2.weight"], **kwargs)
385
+
386
+
387
def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.

    >>> import torch
    >>> from doctr.models import fast_small
    >>> model = fast_small(pretrained=True)
    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture

    Returns:
    -------
        text detection architecture
    """
    feat_layers = ["3", "4", "5", "6"]
    # the classification layer is dropped when fine-tuning on a different class set
    return _fast("fast_small", pretrained, textnet_small, feat_layers, ignore_keys=["prob_head.2.weight"], **kwargs)
414
+
415
+
416
def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.

    >>> import torch
    >>> from doctr.models import fast_base
    >>> model = fast_base(pretrained=True)
    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture

    Returns:
    -------
        text detection architecture
    """
    feat_layers = ["3", "4", "5", "6"]
    # the classification layer is dropped when fine-tuning on a different class set
    return _fast("fast_base", pretrained, textnet_base, feat_layers, ignore_keys=["prob_head.2.weight"], **kwargs)