python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -1,414 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
-
8
- from copy import deepcopy
9
- from typing import Any
10
-
11
- import numpy as np
12
- import tensorflow as tf
13
- from tensorflow.keras import Model, Sequential, layers, losses
14
- from tensorflow.keras.applications import ResNet50
15
-
16
- from doctr.file_utils import CLASS_NAME
17
- from doctr.models.utils import (
18
- IntermediateLayerGetter,
19
- _bf16_to_float32,
20
- _build_model,
21
- conv_sequence,
22
- load_pretrained_params,
23
- )
24
- from doctr.utils.repr import NestedObject
25
-
26
- from ...classification import mobilenet_v3_large
27
- from .base import DBPostProcessor, _DBNet
28
-
29
- __all__ = ["DBNet", "db_resnet50", "db_mobilenet_v3_large"]
30
-
31
-
32
def _db_default_cfg(url: str) -> dict[str, Any]:
    """Build the default configuration shared by every DBNet variant of this module.

    Args:
        url: location of the pretrained checkpoint of the variant

    Returns:
        a fresh dict holding the normalization statistics, the input shape and the checkpoint URL
    """
    return {
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "input_shape": (1024, 1024, 3),
        "url": url,
    }


# Per-architecture defaults: the variants only differ by their pretrained checkpoint URL
default_cfgs: dict[str, dict[str, Any]] = {
    "db_resnet50": _db_default_cfg(
        "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0"
    ),
    "db_mobilenet_v3_large": _db_default_cfg(
        "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0"
    ),
}
46
-
47
-
48
class FeaturePyramidNetwork(layers.Layer, NestedObject):
    """Feature Pyramid Network as described in `"Feature Pyramid Networks for Object Detection"
    <https://arxiv.org/pdf/1612.03144.pdf>`_.

    Each input feature map (up to 4) is projected to `channels` channels by a 1x1 conv.
    The maps are then fused top-down: each coarser map is upsampled 2x and added to the next finer one.
    Each level is then refined by a 3x3 conv block and upsampled, and the levels are concatenated.

    Args:
        channels: number of channels each pyramid level is mapped to (the concatenated output has one
            `channels`-wide slice per level)
    """

    def __init__(
        self,
        channels: int,
    ) -> None:
        super().__init__()
        self.channels = channels
        self.upsample = layers.UpSampling2D(size=(2, 2), interpolation="nearest")
        # One 1x1 projection per pyramid level
        self.inner_blocks = [layers.Conv2D(channels, 1, strides=1, kernel_initializer="he_normal") for _ in range(4)]
        # Level `idx` is upsampled by 2**idx after its 3x3 conv. This presumably assumes each successive
        # input is half the resolution of the previous one, so all levels match before concatenation.
        self.layer_blocks = [self.build_upsampling(channels, dilation_factor=2**idx) for idx in range(4)]

    @staticmethod
    def build_upsampling(
        channels: int,
        dilation_factor: int = 1,
    ) -> layers.Layer:
        """Module which performs a 3x3 convolution followed by up-sampling

        Args:
            channels: number of output channels
            dilation_factor (int): up-sampling factor applied after the convolution, so that the level can be
                concatenated with the others (no up-sampling when it is 1 or lower)

        Returns:
            a keras.layers.Layer object, wrapping these operations in a sequential module

        """
        _layers = conv_sequence(channels, "relu", True, kernel_size=3)

        if dilation_factor > 1:
            _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest"))

        module = Sequential(_layers)

        return module

    def extra_repr(self) -> str:
        return f"channels={self.channels}"

    def call(
        self,
        x: list[tf.Tensor],
        **kwargs: Any,
    ) -> tf.Tensor:
        """Fuse the backbone feature maps into a single feature map.

        Args:
            x: feature maps ordered from the finest (first) to the coarsest (last) level
            **kwargs: forwarded to the projection and refinement blocks

        Returns:
            the concatenation of the refined, upsampled levels
        """
        # Channel mapping
        results = [block(fmap, **kwargs) for block, fmap in zip(self.inner_blocks, x)]
        # Top-down pathway: walk from the second-coarsest level down to the finest one, adding the
        # 2x-upsampled coarser map into each finer map.
        # The previous `range(len(results) - 1, -1)` was empty, so this fusion silently never ran.
        # NOTE(review): checkpoints trained while the loop was a no-op may need re-validation.
        for idx in range(len(results) - 2, -1, -1):
            results[idx] += self.upsample(results[idx + 1])
        # Conv & upsample
        results = [block(fmap, **kwargs) for block, fmap in zip(self.layer_blocks, results)]

        return layers.concatenate(results)
107
-
108
-
109
class DBNet(_DBNet, Model, NestedObject):
    """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
    <https://arxiv.org/pdf/1911.08947.pdf>`_.

    Args:
        feature extractor: the backbone serving as feature extractor
        fpn_channels: number of channels each extracted feature maps is mapped to
        bin_thresh: threshold for binarization
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: if True, fit straight bounding boxes only
        exportable: onnx exportable returns only logits
        cfg: the configuration dict of the model
        class_names: list of class names
    """

    _children_names: list[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]

    def __init__(
        self,
        feature_extractor: IntermediateLayerGetter,
        fpn_channels: int = 128,  # to be set to 256 to represent the author's initial idea
        bin_thresh: float = 0.3,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
        exportable: bool = False,
        cfg: dict[str, Any] | None = None,
        class_names: list[str] = [CLASS_NAME],
    ) -> None:
        super().__init__()
        # Copy so the model never aliases the shared default list, or a list the caller may mutate later
        self.class_names = list(class_names)
        num_classes: int = len(self.class_names)
        self.cfg = cfg

        self.feat_extractor = feature_extractor
        self.exportable = exportable
        self.assume_straight_pages = assume_straight_pages

        self.fpn = FeaturePyramidNetwork(channels=fpn_channels)
        # Initialize kernels: run the FPN once on symbolic inputs shaped like the backbone outputs.
        # This infers the shape of the fused feature map that feeds both heads.
        _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
        output_shape = tuple(self.fpn(_inputs).shape)

        # Both heads share one layout: a 3x3 conv block, then two stride-2 transposed convs
        # (4x upsampling overall). Each ends with one output channel per class.
        self.probability_head = Sequential([
            *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
            layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
            layers.BatchNormalization(),
            layers.Activation("relu"),
            layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
        ])
        self.threshold_head = Sequential([
            *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
            layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
            layers.BatchNormalization(),
            layers.Activation("relu"),
            layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
        ])

        self.postprocessor = DBPostProcessor(
            assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
        )

    def compute_loss(
        self,
        out_map: tf.Tensor,
        thresh_map: tf.Tensor,
        target: list[dict[str, np.ndarray]],
        gamma: float = 2.0,
        alpha: float = 0.5,
        eps: float = 1e-8,
    ) -> tf.Tensor:
        """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
        and a list of masks for each image. From there it computes the loss with the model output

        Args:
            out_map: output feature map of the model of shape (N, H, W, C)
            thresh_map: threshold map of shape (N, H, W, C)
            target: list of dictionary where each dict has a `boxes` and a `flags` entry
            gamma: modulating factor in the focal loss formula
            alpha: balancing factor in the focal loss formula
            eps: epsilon factor in the focal, dice and L1 loss denominators

        Returns:
            A loss tensor

        Raises:
            ValueError: if `gamma` is negative
        """
        if gamma < 0:
            raise ValueError("Value of gamma should be greater than or equal to zero.")

        # Compute the loss in float32. Reduced-precision logits (e.g. bf16/fp16) would otherwise clash
        # with the float32 masks below (TF does not promote dtypes), and summing over every pixel could overflow.
        out_map = tf.cast(out_map, tf.float32)
        thresh_map = tf.cast(thresh_map, tf.float32)

        prob_map = tf.math.sigmoid(out_map)
        thresh_map = tf.math.sigmoid(thresh_map)

        seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True)
        seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
        seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
        seg_mask = tf.cast(seg_mask, tf.float32)
        thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
        thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)

        # Focal loss
        focal_scale = 10.0
        # The trailing axis makes Keras' last-axis averaging a no-op, giving a per-element BCE
        bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)

        # Convert logits to prob, compute gamma factor
        p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
        alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
        # Unreduced loss
        focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
        # Class reduced. `eps` avoids 0 / 0 = NaN when every pixel of the batch is masked out,
        # matching the dice and L1 terms below.
        focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / (
            tf.reduce_sum(seg_mask, (0, 1, 2, 3)) + eps
        )

        # Compute dice loss for each class or for approx binary_map
        if len(self.class_names) > 1:
            dice_map = tf.nn.softmax(out_map, axis=-1)
        else:
            # Differentiable approximation of the binary map: a steep sigmoid of (prob - thresh)
            dice_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
        # Class-reduced dice loss
        inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
        cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
        dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))

        # Compute l1 loss for thresh_map
        if tf.reduce_any(thresh_mask):
            thresh_mask = tf.cast(thresh_mask, tf.float32)
            l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / (
                tf.reduce_sum(thresh_mask) + eps
            )
        else:
            l1_loss = tf.constant(0.0)

        return l1_loss + focal_scale * focal_loss + dice_loss

    def call(
        self,
        x: tf.Tensor,
        target: list[dict[str, np.ndarray]] | None = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Run the detection model.

        Args:
            x: batch of input images
            target: optional annotations; when given, the training loss is computed under "loss"
            return_model_output: if True, include the sigmoid probability map under "out_map"
            return_preds: if True, include post-processed predictions under "preds" even when `target` is given
            **kwargs: forwarded to the feature extractor, the FPN and the heads

        Returns:
            a dict with the requested entries (only "logits" when the model is exportable)
        """
        feat_maps = self.feat_extractor(x, **kwargs)
        feat_concat = self.fpn(feat_maps, **kwargs)
        logits = self.probability_head(feat_concat, **kwargs)

        out: dict[str, Any] = {}
        if self.exportable:
            out["logits"] = logits
            return out

        if return_model_output or target is None or return_preds:
            prob_map = _bf16_to_float32(tf.math.sigmoid(logits))

        if return_model_output:
            out["out_map"] = prob_map

        if target is None or return_preds:
            # Post-process boxes (keep only text predictions)
            out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]

        if target is not None:
            # The threshold head is only needed for the training loss
            thresh_map = self.threshold_head(feat_concat, **kwargs)
            loss = self.compute_loss(logits, thresh_map, target)
            out["loss"] = loss

        return out
273
-
274
-
275
def _db_resnet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: list[str],
    pretrained_backbone: bool = True,
    input_shape: tuple[int, int, int] | None = None,
    **kwargs: Any,
) -> DBNet:
    """Instantiate a DBNet on top of a Keras-applications style backbone.

    Args:
        arch: key of `default_cfgs` describing the architecture
        pretrained: whether to load the pretrained DBNet checkpoint
        backbone_fn: callable building the backbone, called with `weights`, `include_top`, `pooling`
            and `input_shape`
        fpn_layers: names of the backbone layers whose outputs feed the FPN
        pretrained_backbone: whether to request ImageNet backbone weights (ignored when `pretrained` is True)
        input_shape: optional override of the default input shape
        **kwargs: keyword arguments forwarded to `DBNet`

    Returns:
        the DBNet model
    """
    # ImageNet backbone weights are only requested when the full DBNet checkpoint is not loaded
    use_imagenet = pretrained_backbone and not pretrained

    # Patch a private copy of the default config
    cfg = deepcopy(default_cfgs[arch])
    cfg["input_shape"] = input_shape or cfg["input_shape"]
    requested_classes = kwargs.get("class_names", None)
    kwargs["class_names"] = (
        sorted(kwargs["class_names"]) if requested_classes else cfg.get("class_names", [CLASS_NAME])
    )

    backbone = backbone_fn(
        weights="imagenet" if use_imagenet else None,
        include_top=False,
        pooling=None,
        input_shape=cfg["input_shape"],
    )
    model = DBNet(IntermediateLayerGetter(backbone, fpn_layers), cfg=cfg, **kwargs)
    _build_model(model)

    if pretrained:
        # Different class names than the checkpoint => skip the mismatching layers for fine tuning
        classes_differ = kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME])
        load_pretrained_params(model, cfg["url"], skip_mismatch=classes_differ)

    return model
319
-
320
-
321
def _db_mobilenet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: list[str],
    pretrained_backbone: bool = True,
    input_shape: tuple[int, int, int] | None = None,
    **kwargs: Any,
) -> DBNet:
    """Instantiate a DBNet on top of a docTR MobileNet backbone.

    Args:
        arch: key of `default_cfgs` describing the architecture
        pretrained: whether to load the pretrained DBNet checkpoint
        backbone_fn: callable building the backbone, called with `input_shape`, `include_top` and `pretrained`
        fpn_layers: names of the backbone layers whose outputs feed the FPN
        pretrained_backbone: whether to load pretrained backbone weights (ignored when `pretrained` is True)
        input_shape: optional override of the default input shape
        **kwargs: keyword arguments forwarded to `DBNet`

    Returns:
        the DBNet model
    """
    # Backbone weights are only requested when the full DBNet checkpoint is not loaded
    load_backbone_weights = pretrained_backbone and not pretrained

    # Patch a private copy of the default config
    cfg = deepcopy(default_cfgs[arch])
    if input_shape:
        cfg["input_shape"] = input_shape
    requested_classes = kwargs.get("class_names", None)
    if requested_classes:
        kwargs["class_names"] = sorted(kwargs["class_names"])
    else:
        kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])

    backbone = backbone_fn(
        input_shape=cfg["input_shape"],
        include_top=False,
        pretrained=load_backbone_weights,
    )
    model = DBNet(IntermediateLayerGetter(backbone, fpn_layers), cfg=cfg, **kwargs)
    _build_model(model)

    if pretrained:
        # Different class names than the checkpoint => skip the mismatching layers for fine tuning
        classes_differ = kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME])
        load_pretrained_params(model, cfg["url"], skip_mismatch=classes_differ)

    return model
363
-
364
-
365
def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
    """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
    <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import db_resnet50
    >>> model = db_resnet50(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the DBNet architecture

    Returns:
        text detection architecture
    """
    # One feature map per ResNet-50 stage (conv2 to conv5) is fed to the FPN
    fpn_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"]
    return _db_resnet("db_resnet50", pretrained, ResNet50, fpn_layers, **kwargs)
389
-
390
-
391
def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
    """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
    <https://arxiv.org/pdf/1911.08947.pdf>`_, using a mobilenet v3 large backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import db_mobilenet_v3_large
    >>> model = db_mobilenet_v3_large(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the DBNet architecture

    Returns:
        text detection architecture
    """
    # Backbone layers whose outputs are fed to the FPN, from the finest to the coarsest
    fpn_layers = ["inverted_2", "inverted_5", "inverted_11", "final_block"]
    return _db_mobilenet("db_mobilenet_v3_large", pretrained, mobilenet_v3_large, fpn_layers, **kwargs)