python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. doctr/__init__.py +0 -1
  2. doctr/contrib/artefacts.py +1 -1
  3. doctr/contrib/base.py +1 -1
  4. doctr/datasets/__init__.py +0 -5
  5. doctr/datasets/coco_text.py +1 -1
  6. doctr/datasets/cord.py +1 -1
  7. doctr/datasets/datasets/__init__.py +1 -6
  8. doctr/datasets/datasets/base.py +1 -1
  9. doctr/datasets/datasets/pytorch.py +3 -3
  10. doctr/datasets/detection.py +1 -1
  11. doctr/datasets/doc_artefacts.py +1 -1
  12. doctr/datasets/funsd.py +1 -1
  13. doctr/datasets/generator/__init__.py +1 -6
  14. doctr/datasets/generator/base.py +1 -1
  15. doctr/datasets/generator/pytorch.py +1 -1
  16. doctr/datasets/ic03.py +1 -1
  17. doctr/datasets/ic13.py +1 -1
  18. doctr/datasets/iiit5k.py +1 -1
  19. doctr/datasets/iiithws.py +1 -1
  20. doctr/datasets/imgur5k.py +1 -1
  21. doctr/datasets/mjsynth.py +1 -1
  22. doctr/datasets/ocr.py +1 -1
  23. doctr/datasets/orientation.py +1 -1
  24. doctr/datasets/recognition.py +1 -1
  25. doctr/datasets/sroie.py +1 -1
  26. doctr/datasets/svhn.py +1 -1
  27. doctr/datasets/svt.py +1 -1
  28. doctr/datasets/synthtext.py +1 -1
  29. doctr/datasets/utils.py +1 -1
  30. doctr/datasets/vocabs.py +1 -3
  31. doctr/datasets/wildreceipt.py +1 -1
  32. doctr/file_utils.py +3 -102
  33. doctr/io/elements.py +1 -1
  34. doctr/io/html.py +1 -1
  35. doctr/io/image/__init__.py +1 -7
  36. doctr/io/image/base.py +1 -1
  37. doctr/io/image/pytorch.py +2 -2
  38. doctr/io/pdf.py +1 -1
  39. doctr/io/reader.py +1 -1
  40. doctr/models/_utils.py +56 -18
  41. doctr/models/builder.py +1 -1
  42. doctr/models/classification/magc_resnet/__init__.py +1 -6
  43. doctr/models/classification/magc_resnet/pytorch.py +3 -3
  44. doctr/models/classification/mobilenet/__init__.py +1 -6
  45. doctr/models/classification/mobilenet/pytorch.py +1 -1
  46. doctr/models/classification/predictor/__init__.py +1 -6
  47. doctr/models/classification/predictor/pytorch.py +2 -2
  48. doctr/models/classification/resnet/__init__.py +1 -6
  49. doctr/models/classification/resnet/pytorch.py +1 -1
  50. doctr/models/classification/textnet/__init__.py +1 -6
  51. doctr/models/classification/textnet/pytorch.py +2 -2
  52. doctr/models/classification/vgg/__init__.py +1 -6
  53. doctr/models/classification/vgg/pytorch.py +1 -1
  54. doctr/models/classification/vip/__init__.py +1 -4
  55. doctr/models/classification/vip/layers/__init__.py +1 -4
  56. doctr/models/classification/vip/layers/pytorch.py +2 -2
  57. doctr/models/classification/vip/pytorch.py +1 -1
  58. doctr/models/classification/vit/__init__.py +1 -6
  59. doctr/models/classification/vit/pytorch.py +3 -3
  60. doctr/models/classification/zoo.py +7 -12
  61. doctr/models/core.py +1 -1
  62. doctr/models/detection/_utils/__init__.py +1 -6
  63. doctr/models/detection/_utils/base.py +1 -1
  64. doctr/models/detection/_utils/pytorch.py +1 -1
  65. doctr/models/detection/core.py +2 -2
  66. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  67. doctr/models/detection/differentiable_binarization/base.py +5 -13
  68. doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
  69. doctr/models/detection/fast/__init__.py +1 -6
  70. doctr/models/detection/fast/base.py +5 -15
  71. doctr/models/detection/fast/pytorch.py +5 -5
  72. doctr/models/detection/linknet/__init__.py +1 -6
  73. doctr/models/detection/linknet/base.py +4 -13
  74. doctr/models/detection/linknet/pytorch.py +3 -3
  75. doctr/models/detection/predictor/__init__.py +1 -6
  76. doctr/models/detection/predictor/pytorch.py +2 -2
  77. doctr/models/detection/zoo.py +16 -33
  78. doctr/models/factory/hub.py +26 -34
  79. doctr/models/kie_predictor/__init__.py +1 -6
  80. doctr/models/kie_predictor/base.py +1 -1
  81. doctr/models/kie_predictor/pytorch.py +3 -7
  82. doctr/models/modules/layers/__init__.py +1 -6
  83. doctr/models/modules/layers/pytorch.py +4 -4
  84. doctr/models/modules/transformer/__init__.py +1 -6
  85. doctr/models/modules/transformer/pytorch.py +3 -3
  86. doctr/models/modules/vision_transformer/__init__.py +1 -6
  87. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  88. doctr/models/predictor/__init__.py +1 -6
  89. doctr/models/predictor/base.py +4 -9
  90. doctr/models/predictor/pytorch.py +3 -6
  91. doctr/models/preprocessor/__init__.py +1 -6
  92. doctr/models/preprocessor/pytorch.py +28 -33
  93. doctr/models/recognition/core.py +1 -1
  94. doctr/models/recognition/crnn/__init__.py +1 -6
  95. doctr/models/recognition/crnn/pytorch.py +7 -7
  96. doctr/models/recognition/master/__init__.py +1 -6
  97. doctr/models/recognition/master/base.py +1 -1
  98. doctr/models/recognition/master/pytorch.py +6 -6
  99. doctr/models/recognition/parseq/__init__.py +1 -6
  100. doctr/models/recognition/parseq/base.py +1 -1
  101. doctr/models/recognition/parseq/pytorch.py +6 -6
  102. doctr/models/recognition/predictor/__init__.py +1 -6
  103. doctr/models/recognition/predictor/_utils.py +8 -17
  104. doctr/models/recognition/predictor/pytorch.py +2 -3
  105. doctr/models/recognition/sar/__init__.py +1 -6
  106. doctr/models/recognition/sar/pytorch.py +4 -4
  107. doctr/models/recognition/utils.py +1 -1
  108. doctr/models/recognition/viptr/__init__.py +1 -4
  109. doctr/models/recognition/viptr/pytorch.py +4 -4
  110. doctr/models/recognition/vitstr/__init__.py +1 -6
  111. doctr/models/recognition/vitstr/base.py +1 -1
  112. doctr/models/recognition/vitstr/pytorch.py +4 -4
  113. doctr/models/recognition/zoo.py +14 -14
  114. doctr/models/utils/__init__.py +1 -6
  115. doctr/models/utils/pytorch.py +3 -2
  116. doctr/models/zoo.py +1 -1
  117. doctr/transforms/functional/__init__.py +1 -6
  118. doctr/transforms/functional/base.py +3 -2
  119. doctr/transforms/functional/pytorch.py +5 -5
  120. doctr/transforms/modules/__init__.py +1 -7
  121. doctr/transforms/modules/base.py +28 -94
  122. doctr/transforms/modules/pytorch.py +29 -27
  123. doctr/utils/common_types.py +1 -1
  124. doctr/utils/data.py +1 -2
  125. doctr/utils/fonts.py +1 -1
  126. doctr/utils/geometry.py +7 -11
  127. doctr/utils/metrics.py +1 -1
  128. doctr/utils/multithreading.py +1 -1
  129. doctr/utils/reconstitution.py +1 -1
  130. doctr/utils/repr.py +1 -1
  131. doctr/utils/visualization.py +2 -2
  132. doctr/version.py +1 -1
  133. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
  134. python_doctr-1.0.1.dist-info/RECORD +149 -0
  135. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
  136. doctr/datasets/datasets/tensorflow.py +0 -59
  137. doctr/datasets/generator/tensorflow.py +0 -58
  138. doctr/datasets/loader.py +0 -94
  139. doctr/io/image/tensorflow.py +0 -101
  140. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  141. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  142. doctr/models/classification/predictor/tensorflow.py +0 -60
  143. doctr/models/classification/resnet/tensorflow.py +0 -418
  144. doctr/models/classification/textnet/tensorflow.py +0 -275
  145. doctr/models/classification/vgg/tensorflow.py +0 -125
  146. doctr/models/classification/vit/tensorflow.py +0 -201
  147. doctr/models/detection/_utils/tensorflow.py +0 -34
  148. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  149. doctr/models/detection/fast/tensorflow.py +0 -427
  150. doctr/models/detection/linknet/tensorflow.py +0 -377
  151. doctr/models/detection/predictor/tensorflow.py +0 -70
  152. doctr/models/kie_predictor/tensorflow.py +0 -187
  153. doctr/models/modules/layers/tensorflow.py +0 -171
  154. doctr/models/modules/transformer/tensorflow.py +0 -235
  155. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  156. doctr/models/predictor/tensorflow.py +0 -155
  157. doctr/models/preprocessor/tensorflow.py +0 -122
  158. doctr/models/recognition/crnn/tensorflow.py +0 -317
  159. doctr/models/recognition/master/tensorflow.py +0 -320
  160. doctr/models/recognition/parseq/tensorflow.py +0 -516
  161. doctr/models/recognition/predictor/tensorflow.py +0 -79
  162. doctr/models/recognition/sar/tensorflow.py +0 -423
  163. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  164. doctr/models/utils/tensorflow.py +0 -189
  165. doctr/transforms/functional/tensorflow.py +0 -254
  166. doctr/transforms/modules/tensorflow.py +0 -562
  167. python_doctr-0.12.0.dist-info/RECORD +0 -180
  168. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
  169. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
  170. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
@@ -1,427 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
-
8
- from copy import deepcopy
9
- from typing import Any
10
-
11
- import numpy as np
12
- import tensorflow as tf
13
- from tensorflow.keras import Model, Sequential, layers
14
-
15
- from doctr.file_utils import CLASS_NAME
16
- from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params
17
- from doctr.utils.repr import NestedObject
18
-
19
- from ...classification import textnet_base, textnet_small, textnet_tiny
20
- from ...modules.layers import FASTConvLayer
21
- from .base import _FAST, FASTPostProcessor
22
-
23
- __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
24
-
25
-
26
- default_cfgs: dict[str, dict[str, Any]] = {
27
- "fast_tiny": {
28
- "input_shape": (1024, 1024, 3),
29
- "mean": (0.798, 0.785, 0.772),
30
- "std": (0.264, 0.2749, 0.287),
31
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0",
32
- },
33
- "fast_small": {
34
- "input_shape": (1024, 1024, 3),
35
- "mean": (0.798, 0.785, 0.772),
36
- "std": (0.264, 0.2749, 0.287),
37
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0",
38
- },
39
- "fast_base": {
40
- "input_shape": (1024, 1024, 3),
41
- "mean": (0.798, 0.785, 0.772),
42
- "std": (0.264, 0.2749, 0.287),
43
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0",
44
- },
45
- }
46
-
47
-
48
- class FastNeck(layers.Layer, NestedObject):
49
- """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer.
50
-
51
- Args:
52
- in_channels: number of input channels
53
- out_channels: number of output channels
54
- """
55
-
56
- def __init__(
57
- self,
58
- in_channels: int,
59
- out_channels: int = 128,
60
- ) -> None:
61
- super().__init__()
62
- self.reduction = [FASTConvLayer(in_channels * scale, out_channels, kernel_size=3) for scale in [1, 2, 4, 8]]
63
-
64
- def _upsample(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
65
- return tf.image.resize(x, size=y.shape[1:3], method="bilinear")
66
-
67
- def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
68
- f1, f2, f3, f4 = x
69
- f1, f2, f3, f4 = [reduction(f, **kwargs) for reduction, f in zip(self.reduction, (f1, f2, f3, f4))]
70
- f2, f3, f4 = [self._upsample(f, f1) for f in (f2, f3, f4)]
71
- f = tf.concat((f1, f2, f3, f4), axis=-1)
72
- return f
73
-
74
-
75
- class FastHead(Sequential):
76
- """Head of the FAST architecture
77
-
78
- Args:
79
- in_channels: number of input channels
80
- num_classes: number of output classes
81
- out_channels: number of output channels
82
- dropout: dropout probability
83
- """
84
-
85
- def __init__(
86
- self,
87
- in_channels: int,
88
- num_classes: int,
89
- out_channels: int = 128,
90
- dropout: float = 0.1,
91
- ) -> None:
92
- _layers = [
93
- FASTConvLayer(in_channels, out_channels, kernel_size=3),
94
- layers.Dropout(dropout),
95
- layers.Conv2D(num_classes, kernel_size=1, use_bias=False),
96
- ]
97
- super().__init__(_layers)
98
-
99
-
100
- class FAST(_FAST, Model, NestedObject):
101
- """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
102
- <https://arxiv.org/pdf/2111.02394.pdf>`_.
103
-
104
- Args:
105
- feature_extractor: the backbone serving as feature extractor
106
- bin_thresh: threshold for binarization
107
- box_thresh: minimal objectness score to consider a box
108
- dropout_prob: dropout probability
109
- pooling_size: size of the pooling layer
110
- assume_straight_pages: if True, fit straight bounding boxes only
111
- exportable: if True, the forward pass returns only the logits (for ONNX export)
112
- cfg: the configuration dict of the model
113
- class_names: list of class names
114
- """
115
-
116
- _children_names: list[str] = ["feat_extractor", "neck", "head", "postprocessor"]
117
-
118
- def __init__(
119
- self,
120
- feature_extractor: IntermediateLayerGetter,
121
- bin_thresh: float = 0.1,
122
- box_thresh: float = 0.1,
123
- dropout_prob: float = 0.1,
124
- pooling_size: int = 4, # different from paper performs better on close text-rich images
125
- assume_straight_pages: bool = True,
126
- exportable: bool = False,
127
- cfg: dict[str, Any] = {},
128
- class_names: list[str] = [CLASS_NAME],
129
- ) -> None:
130
- super().__init__()
131
- self.class_names = class_names
132
- num_classes: int = len(self.class_names)
133
- self.cfg = cfg
134
-
135
- self.feat_extractor = feature_extractor
136
- self.exportable = exportable
137
- self.assume_straight_pages = assume_straight_pages
138
-
139
- # Identify the number of channels for the neck & head initialization
140
- feat_out_channels = [
141
- layers.Input(shape=in_shape[1:]).shape[-1] for in_shape in self.feat_extractor.output_shape
142
- ]
143
- # Initialize neck & head
144
- self.neck = FastNeck(feat_out_channels[0], feat_out_channels[1])
145
- self.head = FastHead(feat_out_channels[-1], num_classes, feat_out_channels[1], dropout_prob)
146
-
147
- # NOTE: The post-processing from the paper does not work well for text-rich images
148
- # so we use a modified version from DBNet
149
- self.postprocessor = FASTPostProcessor(
150
- assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
151
- )
152
-
153
- # Pooling layer as erosion reversal as described in the paper
154
- self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same")
155
-
156
- def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
157
- """Load pretrained parameters onto the model
158
-
159
- Args:
160
- path_or_url: the path or URL to the model parameters (checkpoint)
161
- **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
162
- """
163
- load_pretrained_params(self, path_or_url, **kwargs)
164
-
165
- def compute_loss(
166
- self,
167
- out_map: tf.Tensor,
168
- target: list[dict[str, np.ndarray]],
169
- eps: float = 1e-6,
170
- ) -> tf.Tensor:
171
- """Compute fast loss: the sum of two Dice losses, where the text segmentation loss is scaled by 0.5.
172
-
173
- Args:
174
- out_map: output feature map of the model of shape (N, H, W, num_classes)
175
- target: list of dictionary where each dict has a `boxes` and a `flags` entry
176
- eps: epsilon factor in dice loss
177
-
178
- Returns:
179
- A loss tensor
180
- """
181
- targets = self.build_target(target, out_map.shape[1:], True)
182
-
183
- seg_target = tf.convert_to_tensor(targets[0], dtype=out_map.dtype)
184
- seg_mask = tf.convert_to_tensor(targets[1], dtype=out_map.dtype)
185
- shrunken_kernel = tf.convert_to_tensor(targets[2], dtype=out_map.dtype)
186
-
187
- def ohem(score: tf.Tensor, gt: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
188
- pos_num = tf.reduce_sum(tf.cast(gt > 0.5, dtype=tf.int32)) - tf.reduce_sum(
189
- tf.cast((gt > 0.5) & (mask <= 0.5), dtype=tf.int32)
190
- )
191
- neg_num = tf.reduce_sum(tf.cast(gt <= 0.5, dtype=tf.int32))
192
- neg_num = tf.minimum(pos_num * 3, neg_num)
193
-
194
- if neg_num == 0 or pos_num == 0:
195
- return mask
196
-
197
- neg_score_sorted, _ = tf.nn.top_k(-tf.boolean_mask(score, gt <= 0.5), k=neg_num)
198
- threshold = -neg_score_sorted[-1]
199
-
200
- selected_mask = tf.math.logical_and((score >= threshold) | (gt > 0.5), (mask > 0.5))
201
- return tf.cast(selected_mask, dtype=tf.float32)
202
-
203
- if len(self.class_names) > 1:
204
- kernels = tf.nn.softmax(out_map, axis=-1)
205
- prob_map = tf.nn.softmax(self.pooling(out_map), axis=-1)
206
- else:
207
- kernels = tf.sigmoid(out_map)
208
- prob_map = tf.sigmoid(self.pooling(out_map))
209
-
210
- # As described in the paper, we use the Dice loss for the text segmentation map, scaled by 0.5.
211
- selected_masks = tf.stack(
212
- [ohem(score, gt, mask) for score, gt, mask in zip(prob_map, seg_target, seg_mask)], axis=0
213
- )
214
- inter = tf.reduce_sum(selected_masks * prob_map * seg_target, axis=(0, 1, 2))
215
- cardinality = tf.reduce_sum(selected_masks * (prob_map + seg_target), axis=(0, 1, 2))
216
- text_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps))) * 0.5
217
-
218
- # As described in the paper, we use the Dice loss for the text kernel map.
219
- selected_masks = seg_target * seg_mask
220
- inter = tf.reduce_sum(selected_masks * kernels * shrunken_kernel, axis=(0, 1, 2))
221
- cardinality = tf.reduce_sum(selected_masks * (kernels + shrunken_kernel), axis=(0, 1, 2))
222
- kernel_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps)))
223
-
224
- return text_loss + kernel_loss
225
-
226
- def call(
227
- self,
228
- x: tf.Tensor,
229
- target: list[dict[str, np.ndarray]] | None = None,
230
- return_model_output: bool = False,
231
- return_preds: bool = False,
232
- **kwargs: Any,
233
- ) -> dict[str, Any]:
234
- feat_maps = self.feat_extractor(x, **kwargs)
235
- # Pass through the Neck & Head & Upsample
236
- feat_concat = self.neck(feat_maps, **kwargs)
237
- logits: tf.Tensor = self.head(feat_concat, **kwargs)
238
- logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)
239
-
240
- out: dict[str, tf.Tensor] = {}
241
- if self.exportable:
242
- out["logits"] = logits
243
- return out
244
-
245
- if return_model_output or target is None or return_preds:
246
- prob_map = _bf16_to_float32(tf.math.sigmoid(self.pooling(logits, **kwargs)))
247
-
248
- if return_model_output:
249
- out["out_map"] = prob_map
250
-
251
- if target is None or return_preds:
252
- # Post-process boxes (keep only text predictions)
253
- out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]
254
-
255
- if target is not None:
256
- loss = self.compute_loss(logits, target)
257
- out["loss"] = loss
258
-
259
- return out
260
-
261
-
262
- def reparameterize(model: FAST | layers.Layer) -> FAST:
263
- """Fuse batchnorm and conv layers and reparameterize the model
264
-
265
- Args:
266
-
267
- model: the FAST model to reparameterize
268
-
269
- Returns:
270
- the reparameterized model
271
- """
272
- last_conv = None
273
- last_conv_idx = None
274
-
275
- for idx, layer in enumerate(model.layers):
276
- if hasattr(layer, "layers") or isinstance(
277
- layer, (FASTConvLayer, FastNeck, FastHead, layers.BatchNormalization, layers.Conv2D)
278
- ):
279
- if isinstance(layer, layers.BatchNormalization):
280
- # fuse batchnorm only if it is preceded by a conv layer
281
- if last_conv is None:
282
- continue
283
- conv_w = last_conv.kernel
284
- conv_b = last_conv.bias if last_conv.use_bias else tf.zeros_like(layer.moving_mean)
285
-
286
- factor = layer.gamma / tf.sqrt(layer.moving_variance + layer.epsilon)
287
- last_conv.kernel = conv_w * factor.numpy().reshape([1, 1, 1, -1])
288
- if last_conv.use_bias:
289
- last_conv.bias.assign((conv_b - layer.moving_mean) * factor + layer.beta)
290
- model.layers[last_conv_idx] = last_conv # Replace the last conv layer with the fused version
291
- model.layers[idx] = layers.Lambda(lambda x: x)
292
- last_conv = None
293
- elif isinstance(layer, layers.Conv2D):
294
- last_conv = layer
295
- last_conv_idx = idx
296
- elif isinstance(layer, FASTConvLayer):
297
- layer.reparameterize_layer()
298
- elif isinstance(layer, FastNeck):
299
- for reduction in layer.reduction:
300
- reduction.reparameterize_layer()
301
- elif isinstance(layer, FastHead):
302
- reparameterize(layer)
303
- else:
304
- reparameterize(layer)
305
- return model
306
-
307
-
308
- def _fast(
309
- arch: str,
310
- pretrained: bool,
311
- backbone_fn,
312
- feat_layers: list[str],
313
- pretrained_backbone: bool = True,
314
- input_shape: tuple[int, int, int] | None = None,
315
- **kwargs: Any,
316
- ) -> FAST:
317
- pretrained_backbone = pretrained_backbone and not pretrained
318
-
319
- # Patch the config
320
- _cfg = deepcopy(default_cfgs[arch])
321
- _cfg["input_shape"] = input_shape or _cfg["input_shape"]
322
- if not kwargs.get("class_names", None):
323
- kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME])
324
- else:
325
- kwargs["class_names"] = sorted(kwargs["class_names"])
326
-
327
- # Feature extractor
328
- feat_extractor = IntermediateLayerGetter(
329
- backbone_fn(
330
- input_shape=_cfg["input_shape"],
331
- include_top=False,
332
- pretrained=pretrained_backbone,
333
- ),
334
- feat_layers,
335
- )
336
-
337
- # Build the model
338
- model = FAST(feat_extractor, cfg=_cfg, **kwargs)
339
- _build_model(model)
340
-
341
- # Load pretrained parameters
342
- if pretrained:
343
- # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
344
- model.from_pretrained(
345
- _cfg["url"],
346
- skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
347
- )
348
-
349
- return model
350
-
351
-
352
- def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
353
- """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
354
- <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
355
-
356
- >>> import tensorflow as tf
357
- >>> from doctr.models import fast_tiny
358
- >>> model = fast_tiny(pretrained=True)
359
- >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
360
- >>> out = model(input_tensor)
361
-
362
- Args:
363
- pretrained (bool): If True, returns a model pre-trained on our text detection dataset
364
- **kwargs: keyword arguments of the FAST architecture
365
-
366
- Returns:
367
- text detection architecture
368
- """
369
- return _fast(
370
- "fast_tiny",
371
- pretrained,
372
- textnet_tiny,
373
- ["stage_0", "stage_1", "stage_2", "stage_3"],
374
- **kwargs,
375
- )
376
-
377
-
378
- def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
379
- """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
380
- <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
381
-
382
- >>> import tensorflow as tf
383
- >>> from doctr.models import fast_small
384
- >>> model = fast_small(pretrained=True)
385
- >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
386
- >>> out = model(input_tensor)
387
-
388
- Args:
389
- pretrained (bool): If True, returns a model pre-trained on our text detection dataset
390
- **kwargs: keyword arguments of the FAST architecture
391
-
392
- Returns:
393
- text detection architecture
394
- """
395
- return _fast(
396
- "fast_small",
397
- pretrained,
398
- textnet_small,
399
- ["stage_0", "stage_1", "stage_2", "stage_3"],
400
- **kwargs,
401
- )
402
-
403
-
404
- def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
405
- """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
406
- <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
407
-
408
- >>> import tensorflow as tf
409
- >>> from doctr.models import fast_base
410
- >>> model = fast_base(pretrained=True)
411
- >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
412
- >>> out = model(input_tensor)
413
-
414
- Args:
415
- pretrained (bool): If True, returns a model pre-trained on our text detection dataset
416
- **kwargs: keyword arguments of the FAST architecture
417
-
418
- Returns:
419
- text detection architecture
420
- """
421
- return _fast(
422
- "fast_base",
423
- pretrained,
424
- textnet_base,
425
- ["stage_0", "stage_1", "stage_2", "stage_3"],
426
- **kwargs,
427
- )