python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
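The headline change in this release is the new FAST text detector: the `fast/` detection modules, the TextNet classification backbones, and the reparameterizable conv layers listed above. As a quick orientation, here is a minimal sketch of how the new architectures would be exercised, assuming the `fast_*` builders are wired into the detection zoo (suggested by the `doctr/models/detection/zoo.py` change above); treat it as illustrative, not as the package's documented API.

    import numpy as np
    from doctr.models import detection_predictor

    # Smoke test for the new FAST detector; the arch name is taken from the
    # default_cfgs in the new module below. Pretrained URLs are still None
    # in this release, so we build with random weights.
    predictor = detection_predictor(arch="fast_base", pretrained=False)
    page = np.random.rand(1024, 1024, 3).astype(np.float32)  # one dummy page in [0, 1]
    out = predictor([page])  # list with one dict mapping class names to relative boxes
    print(out)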
doctr/models/detection/fast/tensorflow.py (new file)
@@ -0,0 +1,428 @@
+ # Copyright (C) 2021-2024, Mindee.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+ from copy import deepcopy
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow import keras
+ from tensorflow.keras import Sequential, layers
+
+ from doctr.file_utils import CLASS_NAME
+ from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params
+ from doctr.utils.repr import NestedObject
+
+ from ...classification import textnet_base, textnet_small, textnet_tiny
+ from ...modules.layers import FASTConvLayer
+ from .base import _FAST, FASTPostProcessor
+
+ __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
+
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "fast_tiny": {
+         "input_shape": (1024, 1024, 3),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": None,
+     },
+     "fast_small": {
+         "input_shape": (1024, 1024, 3),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": None,
+     },
+     "fast_base": {
+         "input_shape": (1024, 1024, 3),
+         "mean": (0.798, 0.785, 0.772),
+         "std": (0.264, 0.2749, 0.287),
+         "url": None,
+     },
+ }
+
+
+ class FastNeck(layers.Layer, NestedObject):
+     """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layers.
+
+     Args:
+     ----
+         in_channels: number of input channels
+         out_channels: number of output channels
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int = 128,
+     ) -> None:
+         super().__init__()
+         self.reduction = [FASTConvLayer(in_channels * scale, out_channels, kernel_size=3) for scale in [1, 2, 4, 8]]
+
+     def _upsample(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
+         return tf.image.resize(x, size=y.shape[1:3], method="bilinear")
+
+     def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+         f1, f2, f3, f4 = x
+         f1, f2, f3, f4 = [reduction(f, **kwargs) for reduction, f in zip(self.reduction, (f1, f2, f3, f4))]
+         f2, f3, f4 = [self._upsample(f, f1) for f in (f2, f3, f4)]
+         f = tf.concat((f1, f2, f3, f4), axis=-1)
+         return f
+
+
+ class FastHead(Sequential):
+     """Head of the FAST architecture
+
+     Args:
+     ----
+         in_channels: number of input channels
+         num_classes: number of output classes
+         out_channels: number of output channels
+         dropout: dropout probability
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         num_classes: int,
+         out_channels: int = 128,
+         dropout: float = 0.1,
+     ) -> None:
+         _layers = [
+             FASTConvLayer(in_channels, out_channels, kernel_size=3),
+             layers.Dropout(dropout),
+             layers.Conv2D(num_classes, kernel_size=1, use_bias=False),
+         ]
+         super().__init__(_layers)
+
+
+ class FAST(_FAST, keras.Model, NestedObject):
+     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+     <https://arxiv.org/pdf/2111.02394.pdf>`_.
+
+     Args:
+     ----
+         feature_extractor: the backbone serving as feature extractor
+         bin_thresh: threshold for binarization
+         box_thresh: minimal objectness score to consider a box
+         dropout_prob: dropout probability
+         pooling_size: size of the pooling layer
+         assume_straight_pages: if True, fit straight bounding boxes only
+         exportable: onnx exportable returns only logits
+         cfg: the configuration dict of the model
+         class_names: list of class names
+     """
+
+     _children_names: List[str] = ["feat_extractor", "neck", "head", "postprocessor"]
+
+     def __init__(
+         self,
+         feature_extractor: IntermediateLayerGetter,
+         bin_thresh: float = 0.3,
+         box_thresh: float = 0.1,
+         dropout_prob: float = 0.1,
+         pooling_size: int = 4,  # differs from the paper: performs better on dense, text-rich images
+         assume_straight_pages: bool = True,
+         exportable: bool = False,
+         cfg: Optional[Dict[str, Any]] = {},
+         class_names: List[str] = [CLASS_NAME],
+     ) -> None:
+         super().__init__()
+         self.class_names = class_names
+         num_classes: int = len(self.class_names)
+         self.cfg = cfg
+
+         self.feat_extractor = feature_extractor
+         self.exportable = exportable
+         self.assume_straight_pages = assume_straight_pages
+
+         # Identify the number of channels for the neck & head initialization
+         feat_out_channels = [
+             layers.Input(shape=in_shape[1:]).shape[-1] for in_shape in self.feat_extractor.output_shape
+         ]
+         # Initialize neck & head
+         self.neck = FastNeck(feat_out_channels[0], feat_out_channels[1])
+         self.head = FastHead(feat_out_channels[-1], num_classes, feat_out_channels[1], dropout_prob)
+
+         # NOTE: the post-processing from the paper does not work well on text-rich images,
+         # so we use a modified version from DBNet
+         self.postprocessor = FASTPostProcessor(
+             assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+         )
+
+         # Pooling layer to reverse the erosion, as described in the paper
+         self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same")
+
+     def compute_loss(
+         self,
+         out_map: tf.Tensor,
+         target: List[Dict[str, np.ndarray]],
+         eps: float = 1e-6,
+     ) -> tf.Tensor:
+         """Compute the FAST loss: the sum of two Dice losses, where the text segmentation term is scaled by 0.5.
+
+         Args:
+         ----
+             out_map: output feature map of the model of shape (N, H, W, num_classes)
+             target: list of dictionaries where each dict has a `boxes` and a `flags` entry
+             eps: epsilon factor in dice loss
+
+         Returns:
+         -------
+             A loss tensor
+         """
+         targets = self.build_target(target, out_map.shape[1:], True)
+
+         seg_target = tf.convert_to_tensor(targets[0], dtype=out_map.dtype)
+         seg_mask = tf.convert_to_tensor(targets[1], dtype=out_map.dtype)
+         shrunken_kernel = tf.convert_to_tensor(targets[2], dtype=out_map.dtype)
+
+         def ohem(score: tf.Tensor, gt: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
+             pos_num = tf.reduce_sum(tf.cast(gt > 0.5, dtype=tf.int32)) - tf.reduce_sum(
+                 tf.cast((gt > 0.5) & (mask <= 0.5), dtype=tf.int32)
+             )
+             neg_num = tf.reduce_sum(tf.cast(gt <= 0.5, dtype=tf.int32))
+             neg_num = tf.minimum(pos_num * 3, neg_num)
+
+             if neg_num == 0 or pos_num == 0:
+                 return mask
+
+             neg_score_sorted, _ = tf.nn.top_k(-tf.boolean_mask(score, gt <= 0.5), k=neg_num)
+             threshold = -neg_score_sorted[-1]
+
+             selected_mask = tf.math.logical_and((score >= threshold) | (gt > 0.5), (mask > 0.5))
+             return tf.cast(selected_mask, dtype=tf.float32)
+
+         if len(self.class_names) > 1:
+             kernels = tf.nn.softmax(out_map, axis=-1)
+             prob_map = tf.nn.softmax(self.pooling(out_map), axis=-1)
+         else:
+             kernels = tf.sigmoid(out_map)
+             prob_map = tf.sigmoid(self.pooling(out_map))
+
+         # As described in the paper: Dice loss over the OHEM-selected text segmentation map, scaled by 0.5
+         selected_masks = tf.stack(
+             [ohem(score, gt, mask) for score, gt, mask in zip(prob_map, seg_target, seg_mask)], axis=0
+         )
+         inter = tf.reduce_sum(selected_masks * prob_map * seg_target, axis=(0, 1, 2))
+         cardinality = tf.reduce_sum(selected_masks * (prob_map + seg_target), axis=(0, 1, 2))
+         text_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps))) * 0.5
+
+         # As described in the paper: Dice loss over the shrunken text kernel map
+         selected_masks = seg_target * seg_mask
+         inter = tf.reduce_sum(selected_masks * kernels * shrunken_kernel, axis=(0, 1, 2))
+         cardinality = tf.reduce_sum(selected_masks * (kernels + shrunken_kernel), axis=(0, 1, 2))
+         kernel_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps)))
+
+         return text_loss + kernel_loss
+
+     def call(
+         self,
+         x: tf.Tensor,
+         target: Optional[List[Dict[str, np.ndarray]]] = None,
+         return_model_output: bool = False,
+         return_preds: bool = False,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         feat_maps = self.feat_extractor(x, **kwargs)
+         # Pass through the neck & head & upsample
+         feat_concat = self.neck(feat_maps, **kwargs)
+         logits: tf.Tensor = self.head(feat_concat, **kwargs)
+         logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)
+
+         out: Dict[str, tf.Tensor] = {}
+         if self.exportable:
+             out["logits"] = logits
+             return out
+
+         if return_model_output or target is None or return_preds:
+             prob_map = _bf16_to_float32(tf.math.sigmoid(self.pooling(logits, **kwargs)))
+
+         if return_model_output:
+             out["out_map"] = prob_map
+
+         if target is None or return_preds:
+             # Post-process boxes (keep only text predictions)
+             out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]
+
+         if target is not None:
+             loss = self.compute_loss(logits, target)
+             out["loss"] = loss
+
+         return out
+
+
+ def reparameterize(model: Union[FAST, layers.Layer]) -> FAST:
+     """Fuse batchnorm and conv layers and reparameterize the model
+
+     Args:
+     ----
+         model: the FAST model to reparameterize
+
+     Returns:
+     -------
+         the reparameterized model
+     """
+     last_conv = None
+     last_conv_idx = None
+
+     for idx, layer in enumerate(model.layers):
+         if hasattr(layer, "layers") or isinstance(
+             layer, (FASTConvLayer, FastNeck, FastHead, layers.BatchNormalization, layers.Conv2D)
+         ):
+             if isinstance(layer, layers.BatchNormalization):
+                 # fuse batchnorm only if it is preceded by a conv layer
+                 if last_conv is None:
+                     continue
+                 conv_w = last_conv.kernel
+                 conv_b = last_conv.bias if last_conv.use_bias else tf.zeros_like(layer.moving_mean)
+
+                 factor = layer.gamma / tf.sqrt(layer.moving_variance + layer.epsilon)
+                 last_conv.kernel = conv_w * factor.numpy().reshape([1, 1, 1, -1])
+                 if last_conv.use_bias:
+                     last_conv.bias.assign((conv_b - layer.moving_mean) * factor + layer.beta)
+                 model.layers[last_conv_idx] = last_conv  # Replace the last conv layer with the fused version
+                 model.layers[idx] = layers.Lambda(lambda x: x)
+                 last_conv = None
+             elif isinstance(layer, layers.Conv2D):
+                 last_conv = layer
+                 last_conv_idx = idx
+             elif isinstance(layer, FASTConvLayer):
+                 layer.reparameterize_layer()
+             elif isinstance(layer, FastNeck):
+                 for reduction in layer.reduction:
+                     reduction.reparameterize_layer()
+             elif isinstance(layer, FastHead):
+                 reparameterize(layer)
+             else:
+                 reparameterize(layer)
+     return model
+
+
+ def _fast(
+     arch: str,
+     pretrained: bool,
+     backbone_fn,
+     feat_layers: List[str],
+     pretrained_backbone: bool = True,
+     input_shape: Optional[Tuple[int, int, int]] = None,
+     **kwargs: Any,
+ ) -> FAST:
+     pretrained_backbone = pretrained_backbone and not pretrained
+
+     # Patch the config
+     _cfg = deepcopy(default_cfgs[arch])
+     _cfg["input_shape"] = input_shape or _cfg["input_shape"]
+     if not kwargs.get("class_names", None):
+         kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME])
+     else:
+         kwargs["class_names"] = sorted(kwargs["class_names"])
+
+     # Feature extractor
+     feat_extractor = IntermediateLayerGetter(
+         backbone_fn(
+             input_shape=_cfg["input_shape"],
+             include_top=False,
+             pretrained=pretrained_backbone,
+         ),
+         feat_layers,
+     )
+
+     # Build the model
+     model = FAST(feat_extractor, cfg=_cfg, **kwargs)
+     # Load pretrained parameters
+     if pretrained:
+         load_pretrained_params(model, _cfg["url"])
+
+     # Build the model for reparameterization to access the layers
+     _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False)
+
+     return model
+
+
+ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
+     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
+
+     >>> import tensorflow as tf
+     >>> from doctr.models import fast_tiny
+     >>> model = fast_tiny(pretrained=True)
+     >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+         **kwargs: keyword arguments of the FAST architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _fast(
+         "fast_tiny",
+         pretrained,
+         textnet_tiny,
+         ["stage_0", "stage_1", "stage_2", "stage_3"],
+         **kwargs,
+     )
+
+
+ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
+     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
+
+     >>> import tensorflow as tf
+     >>> from doctr.models import fast_small
+     >>> model = fast_small(pretrained=True)
+     >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+         **kwargs: keyword arguments of the FAST architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _fast(
+         "fast_small",
+         pretrained,
+         textnet_small,
+         ["stage_0", "stage_1", "stage_2", "stage_3"],
+         **kwargs,
+     )
+
+
+ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
+     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+     <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
+
+     >>> import tensorflow as tf
+     >>> from doctr.models import fast_base
+     >>> model = fast_base(pretrained=True)
+     >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+         **kwargs: keyword arguments of the FAST architecture
+
+     Returns:
+     -------
+         text detection architecture
+     """
+     return _fast(
+         "fast_base",
+         pretrained,
+         textnet_base,
+         ["stage_0", "stage_1", "stage_2", "stage_3"],
+         **kwargs,
+     )
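Since `reparameterize` is part of the module's `__all__`, the intended inference-time flow is presumably to build a model and fold the batchnorm statistics into the preceding convolutions before deployment. A minimal sketch, assuming the helper is re-exported at the `doctr.models` level like the `fast_*` builders (otherwise import it from the detection package):

    import tensorflow as tf
    from doctr.models import fast_tiny, reparameterize  # import path assumed

    model = fast_tiny(pretrained=False)  # _fast() already runs a dummy forward pass, so the layers are built
    model = reparameterize(model)        # conv + batchnorm pairs are fused in place for faster inference
    out = model(tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32), training=False)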
doctr/models/detection/linknet/base.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -23,6 +23,7 @@ class LinkNetPostProcessor(DetectionPostProcessor):
      """Implements a post processor for LinkNet model.
 
      Args:
+     ----
          bin_thresh: threshold used to binarize p_map at inference time
          box_thresh: minimal objectness score to consider a box
          assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -35,7 +36,7 @@ class LinkNetPostProcessor(DetectionPostProcessor):
          assume_straight_pages: bool = True,
      ) -> None:
          super().__init__(box_thresh, bin_thresh, assume_straight_pages)
-         self.unclip_ratio = 1.2
+         self.unclip_ratio = 1.5
 
      def polygon_to_box(
          self,
@@ -44,9 +45,11 @@ class LinkNetPostProcessor(DetectionPostProcessor):
          """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
          Args:
+         ----
              points: The first parameter.
 
          Returns:
+         -------
              a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
          """
          if not self.assume_straight_pages:
@@ -78,7 +81,7 @@ class LinkNetPostProcessor(DetectionPostProcessor):
          if len(expanded_points) < 1:
              return None  # type: ignore[return-value]
          return (
-             cv2.boundingRect(expanded_points)
+             cv2.boundingRect(expanded_points)  # type: ignore[return-value]
              if self.assume_straight_pages
              else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
          )
@@ -91,12 +94,14 @@ class LinkNetPostProcessor(DetectionPostProcessor):
          """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
          Args:
+         ----
              pred: Pred map from differentiable linknet output
              bitmap: Bitmap map computed from pred (binarized)
              angle_tol: Comparison tolerance of the angle with the median angle across the page
              ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
          Returns:
+         -------
              np tensor boxes for the bitmap, each box is a 6-element list
              containing x, y, w, h, alpha, score for the box
          """
@@ -146,6 +151,7 @@ class _LinkNet(BaseModel):
      <https://arxiv.org/pdf/1707.03718.pdf>`_.
 
      Args:
+     ----
          out_chan: number of channels for the output
      """
 
@@ -162,14 +168,15 @@ class _LinkNet(BaseModel):
          """Build the target, and its mask to be used from loss computation.
 
          Args:
+         ----
              target: target coming from dataset
              output_shape: shape of the output of the model without batch_size
              channels_last: whether channels are last or not
 
          Returns:
+         -------
              the new formatted target and the mask
          """
-
          if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
              raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
          if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
@@ -239,7 +246,7 @@ class _LinkNet(BaseModel):
                  if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                      seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                      continue
-                 cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1)
+                 cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)  # type: ignore[call-overload]
 
          # Don't forget to switch back to channel last if Tensorflow is used
          if channels_last:
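For context on the `unclip_ratio` bump from 1.2 to 1.5: the post-processor grows each detected polygon before fitting a box, offsetting it outward so its area increases by roughly that factor. A minimal sketch of that expansion, assuming the pyclipper-based offset doctr uses in its detection post-processors (the helper name here is ours, for illustration only):

    import cv2
    import numpy as np
    import pyclipper
    from shapely.geometry import Polygon

    def unclip_to_box(points: np.ndarray, unclip_ratio: float = 1.5) -> np.ndarray:
        """Expand an (N, 2) polygon by unclip_ratio, then fit an axis-aligned box."""
        poly = Polygon(points)
        distance = poly.area * unclip_ratio / poly.length  # common offset-distance heuristic
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(points.astype(np.int32).tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.asarray(offset.Execute(distance)[0])
        return np.array(cv2.boundingRect(expanded))  # (xmin, ymin, w, h)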
doctr/models/detection/linknet/pytorch.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -14,7 +14,7 @@ from torchvision.models._utils import IntermediateLayerGetter
  from doctr.file_utils import CLASS_NAME
  from doctr.models.classification import resnet18, resnet34, resnet50
 
- from ...utils import load_pretrained_params
+ from ...utils import _bf16_to_float32, load_pretrained_params
  from .base import LinkNetPostProcessor, _LinkNet
 
  __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
@@ -25,19 +25,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
          "input_shape": (3, 1024, 1024),
          "mean": (0.798, 0.785, 0.772),
          "std": (0.264, 0.2749, 0.287),
-         "url": None,
+         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-e47a14dc.pt&src=0",
      },
      "linknet_resnet34": {
          "input_shape": (3, 1024, 1024),
          "mean": (0.798, 0.785, 0.772),
          "std": (0.264, 0.2749, 0.287),
-         "url": None,
+         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-9ca2df3e.pt&src=0",
      },
      "linknet_resnet50": {
          "input_shape": (3, 1024, 1024),
          "mean": (0.798, 0.785, 0.772),
          "std": (0.264, 0.2749, 0.287),
-         "url": None,
+         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-6cf565c1.pt&src=0",
      },
  }
 
@@ -61,7 +61,6 @@ class LinkNetFPN(nn.Module):
      @staticmethod
      def decoder_block(in_chan: int, out_chan: int, stride: int) -> nn.Sequential:
          """Creates a LinkNet decoder block"""
-
          mid_chan = in_chan // 4
          return nn.Sequential(
              nn.Conv2d(in_chan, mid_chan, kernel_size=1, bias=False),
@@ -90,7 +89,10 @@ class LinkNet(nn.Module, _LinkNet):
      <https://arxiv.org/pdf/1707.03718.pdf>`_.
 
      Args:
+     ----
          feat_extractor: the backbone serving as feature extractor
+         bin_thresh: threshold for binarization of the output feature map
+         box_thresh: minimal objectness score to consider a box
          head_chans: number of channels in the head layers
          assume_straight_pages: if True, fit straight bounding boxes only
          exportable: onnx exportable returns only logits
@@ -102,6 +104,7 @@ class LinkNet(nn.Module, _LinkNet):
          self,
          feat_extractor: IntermediateLayerGetter,
          bin_thresh: float = 0.1,
+         box_thresh: float = 0.1,
          head_chans: int = 32,
          assume_straight_pages: bool = True,
          exportable: bool = False,
@@ -142,7 +145,7 @@ class LinkNet(nn.Module, _LinkNet):
          )
 
          self.postprocessor = LinkNetPostProcessor(
-             assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh
+             assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
          )
 
          for n, m in self.named_modules():
@@ -175,7 +178,7 @@ class LinkNet(nn.Module, _LinkNet):
              return out
 
          if return_model_output or target is None or return_preds:
-             prob_map = torch.sigmoid(logits)
+             prob_map = _bf16_to_float32(torch.sigmoid(logits))
              if return_model_output:
                  out["out_map"] = prob_map
 
@@ -204,6 +207,7 @@ class LinkNet(nn.Module, _LinkNet):
          <https://github.com/tensorflow/addons/>`_.
 
          Args:
+         ----
              out_map: output feature map of the model of shape (N, num_classes, H, W)
              target: list of dictionaries where each dict has a `boxes` and a `flags` entry
              gamma: modulating factor in the focal loss formula
@@ -211,6 +215,7 @@ class LinkNet(nn.Module, _LinkNet):
              eps: epsilon factor in dice loss
 
          Returns:
+         -------
              A loss tensor
          """
          _target, _mask = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
@@ -232,10 +237,12 @@ class LinkNet(nn.Module, _LinkNet):
          # Class reduced
          focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))
 
-         # Dice loss
-         inter = (seg_mask * proba_map * seg_target).sum((0, 1, 2, 3))
-         cardinality = (seg_mask * (proba_map + seg_target)).sum((0, 1, 2, 3))
-         dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)
+         # Compute dice loss for each class
+         dice_map = torch.softmax(out_map, dim=1) if len(self.class_names) > 1 else proba_map
+         # Class reduced
+         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
+         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
+         dice_loss = (1 - 2 * inter / (cardinality + eps)).mean()
 
          # Return the full loss (equal sum of focal loss and dice loss)
          return focal_loss + dice_loss
@@ -288,12 +295,14 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
      >>> out = model(input_tensor)
 
      Args:
+     ----
          pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+         **kwargs: keyword arguments of the LinkNet architecture
 
      Returns:
+     -------
          text detection architecture
      """
-
      return _linknet(
          "linknet_resnet18",
          pretrained,
@@ -318,12 +327,14 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
      >>> out = model(input_tensor)
 
      Args:
+     ----
          pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+         **kwargs: keyword arguments of the LinkNet architecture
 
      Returns:
+     -------
          text detection architecture
      """
-
      return _linknet(
          "linknet_resnet34",
          pretrained,
@@ -348,12 +359,14 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
      >>> out = model(input_tensor)
 
      Args:
+     ----
          pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+         **kwargs: keyword arguments of the LinkNet architecture
 
      Returns:
+     -------
          text detection architecture
      """
-
      return _linknet(
          "linknet_resnet50",
          pretrained,