python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. doctr/__init__.py +0 -1
  2. doctr/contrib/artefacts.py +1 -1
  3. doctr/contrib/base.py +1 -1
  4. doctr/datasets/__init__.py +0 -5
  5. doctr/datasets/coco_text.py +1 -1
  6. doctr/datasets/cord.py +1 -1
  7. doctr/datasets/datasets/__init__.py +1 -6
  8. doctr/datasets/datasets/base.py +1 -1
  9. doctr/datasets/datasets/pytorch.py +3 -3
  10. doctr/datasets/detection.py +1 -1
  11. doctr/datasets/doc_artefacts.py +1 -1
  12. doctr/datasets/funsd.py +1 -1
  13. doctr/datasets/generator/__init__.py +1 -6
  14. doctr/datasets/generator/base.py +1 -1
  15. doctr/datasets/generator/pytorch.py +1 -1
  16. doctr/datasets/ic03.py +1 -1
  17. doctr/datasets/ic13.py +1 -1
  18. doctr/datasets/iiit5k.py +1 -1
  19. doctr/datasets/iiithws.py +1 -1
  20. doctr/datasets/imgur5k.py +1 -1
  21. doctr/datasets/mjsynth.py +1 -1
  22. doctr/datasets/ocr.py +1 -1
  23. doctr/datasets/orientation.py +1 -1
  24. doctr/datasets/recognition.py +1 -1
  25. doctr/datasets/sroie.py +1 -1
  26. doctr/datasets/svhn.py +1 -1
  27. doctr/datasets/svt.py +1 -1
  28. doctr/datasets/synthtext.py +1 -1
  29. doctr/datasets/utils.py +1 -1
  30. doctr/datasets/vocabs.py +1 -3
  31. doctr/datasets/wildreceipt.py +1 -1
  32. doctr/file_utils.py +3 -102
  33. doctr/io/elements.py +1 -1
  34. doctr/io/html.py +1 -1
  35. doctr/io/image/__init__.py +1 -7
  36. doctr/io/image/base.py +1 -1
  37. doctr/io/image/pytorch.py +2 -2
  38. doctr/io/pdf.py +1 -1
  39. doctr/io/reader.py +1 -1
  40. doctr/models/_utils.py +56 -18
  41. doctr/models/builder.py +1 -1
  42. doctr/models/classification/magc_resnet/__init__.py +1 -6
  43. doctr/models/classification/magc_resnet/pytorch.py +3 -3
  44. doctr/models/classification/mobilenet/__init__.py +1 -6
  45. doctr/models/classification/mobilenet/pytorch.py +1 -1
  46. doctr/models/classification/predictor/__init__.py +1 -6
  47. doctr/models/classification/predictor/pytorch.py +2 -2
  48. doctr/models/classification/resnet/__init__.py +1 -6
  49. doctr/models/classification/resnet/pytorch.py +1 -1
  50. doctr/models/classification/textnet/__init__.py +1 -6
  51. doctr/models/classification/textnet/pytorch.py +2 -2
  52. doctr/models/classification/vgg/__init__.py +1 -6
  53. doctr/models/classification/vgg/pytorch.py +1 -1
  54. doctr/models/classification/vip/__init__.py +1 -4
  55. doctr/models/classification/vip/layers/__init__.py +1 -4
  56. doctr/models/classification/vip/layers/pytorch.py +2 -2
  57. doctr/models/classification/vip/pytorch.py +1 -1
  58. doctr/models/classification/vit/__init__.py +1 -6
  59. doctr/models/classification/vit/pytorch.py +3 -3
  60. doctr/models/classification/zoo.py +7 -12
  61. doctr/models/core.py +1 -1
  62. doctr/models/detection/_utils/__init__.py +1 -6
  63. doctr/models/detection/_utils/base.py +1 -1
  64. doctr/models/detection/_utils/pytorch.py +1 -1
  65. doctr/models/detection/core.py +2 -2
  66. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  67. doctr/models/detection/differentiable_binarization/base.py +5 -13
  68. doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
  69. doctr/models/detection/fast/__init__.py +1 -6
  70. doctr/models/detection/fast/base.py +5 -15
  71. doctr/models/detection/fast/pytorch.py +5 -5
  72. doctr/models/detection/linknet/__init__.py +1 -6
  73. doctr/models/detection/linknet/base.py +4 -13
  74. doctr/models/detection/linknet/pytorch.py +3 -3
  75. doctr/models/detection/predictor/__init__.py +1 -6
  76. doctr/models/detection/predictor/pytorch.py +2 -2
  77. doctr/models/detection/zoo.py +16 -33
  78. doctr/models/factory/hub.py +26 -34
  79. doctr/models/kie_predictor/__init__.py +1 -6
  80. doctr/models/kie_predictor/base.py +1 -1
  81. doctr/models/kie_predictor/pytorch.py +3 -7
  82. doctr/models/modules/layers/__init__.py +1 -6
  83. doctr/models/modules/layers/pytorch.py +4 -4
  84. doctr/models/modules/transformer/__init__.py +1 -6
  85. doctr/models/modules/transformer/pytorch.py +3 -3
  86. doctr/models/modules/vision_transformer/__init__.py +1 -6
  87. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  88. doctr/models/predictor/__init__.py +1 -6
  89. doctr/models/predictor/base.py +4 -9
  90. doctr/models/predictor/pytorch.py +3 -6
  91. doctr/models/preprocessor/__init__.py +1 -6
  92. doctr/models/preprocessor/pytorch.py +28 -33
  93. doctr/models/recognition/core.py +1 -1
  94. doctr/models/recognition/crnn/__init__.py +1 -6
  95. doctr/models/recognition/crnn/pytorch.py +7 -7
  96. doctr/models/recognition/master/__init__.py +1 -6
  97. doctr/models/recognition/master/base.py +1 -1
  98. doctr/models/recognition/master/pytorch.py +6 -6
  99. doctr/models/recognition/parseq/__init__.py +1 -6
  100. doctr/models/recognition/parseq/base.py +1 -1
  101. doctr/models/recognition/parseq/pytorch.py +6 -6
  102. doctr/models/recognition/predictor/__init__.py +1 -6
  103. doctr/models/recognition/predictor/_utils.py +8 -17
  104. doctr/models/recognition/predictor/pytorch.py +2 -3
  105. doctr/models/recognition/sar/__init__.py +1 -6
  106. doctr/models/recognition/sar/pytorch.py +4 -4
  107. doctr/models/recognition/utils.py +1 -1
  108. doctr/models/recognition/viptr/__init__.py +1 -4
  109. doctr/models/recognition/viptr/pytorch.py +4 -4
  110. doctr/models/recognition/vitstr/__init__.py +1 -6
  111. doctr/models/recognition/vitstr/base.py +1 -1
  112. doctr/models/recognition/vitstr/pytorch.py +4 -4
  113. doctr/models/recognition/zoo.py +14 -14
  114. doctr/models/utils/__init__.py +1 -6
  115. doctr/models/utils/pytorch.py +3 -2
  116. doctr/models/zoo.py +1 -1
  117. doctr/transforms/functional/__init__.py +1 -6
  118. doctr/transforms/functional/base.py +3 -2
  119. doctr/transforms/functional/pytorch.py +5 -5
  120. doctr/transforms/modules/__init__.py +1 -7
  121. doctr/transforms/modules/base.py +28 -94
  122. doctr/transforms/modules/pytorch.py +29 -27
  123. doctr/utils/common_types.py +1 -1
  124. doctr/utils/data.py +1 -2
  125. doctr/utils/fonts.py +1 -1
  126. doctr/utils/geometry.py +7 -11
  127. doctr/utils/metrics.py +1 -1
  128. doctr/utils/multithreading.py +1 -1
  129. doctr/utils/reconstitution.py +1 -1
  130. doctr/utils/repr.py +1 -1
  131. doctr/utils/visualization.py +2 -2
  132. doctr/version.py +1 -1
  133. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
  134. python_doctr-1.0.1.dist-info/RECORD +149 -0
  135. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
  136. doctr/datasets/datasets/tensorflow.py +0 -59
  137. doctr/datasets/generator/tensorflow.py +0 -58
  138. doctr/datasets/loader.py +0 -94
  139. doctr/io/image/tensorflow.py +0 -101
  140. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  141. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  142. doctr/models/classification/predictor/tensorflow.py +0 -60
  143. doctr/models/classification/resnet/tensorflow.py +0 -418
  144. doctr/models/classification/textnet/tensorflow.py +0 -275
  145. doctr/models/classification/vgg/tensorflow.py +0 -125
  146. doctr/models/classification/vit/tensorflow.py +0 -201
  147. doctr/models/detection/_utils/tensorflow.py +0 -34
  148. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  149. doctr/models/detection/fast/tensorflow.py +0 -427
  150. doctr/models/detection/linknet/tensorflow.py +0 -377
  151. doctr/models/detection/predictor/tensorflow.py +0 -70
  152. doctr/models/kie_predictor/tensorflow.py +0 -187
  153. doctr/models/modules/layers/tensorflow.py +0 -171
  154. doctr/models/modules/transformer/tensorflow.py +0 -235
  155. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  156. doctr/models/predictor/tensorflow.py +0 -155
  157. doctr/models/preprocessor/tensorflow.py +0 -122
  158. doctr/models/recognition/crnn/tensorflow.py +0 -317
  159. doctr/models/recognition/master/tensorflow.py +0 -320
  160. doctr/models/recognition/parseq/tensorflow.py +0 -516
  161. doctr/models/recognition/predictor/tensorflow.py +0 -79
  162. doctr/models/recognition/sar/tensorflow.py +0 -423
  163. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  164. doctr/models/utils/tensorflow.py +0 -189
  165. doctr/transforms/functional/tensorflow.py +0 -254
  166. doctr/transforms/modules/tensorflow.py +0 -562
  167. python_doctr-0.12.0.dist-info/RECORD +0 -180
  168. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
  169. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
  170. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
@@ -1,421 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
-
8
- from copy import deepcopy
9
- from typing import Any
10
-
11
- import numpy as np
12
- import tensorflow as tf
13
- from tensorflow.keras import Model, Sequential, layers, losses
14
- from tensorflow.keras.applications import ResNet50
15
-
16
- from doctr.file_utils import CLASS_NAME
17
- from doctr.models.utils import (
18
- IntermediateLayerGetter,
19
- _bf16_to_float32,
20
- _build_model,
21
- conv_sequence,
22
- load_pretrained_params,
23
- )
24
- from doctr.utils.repr import NestedObject
25
-
26
- from ...classification import mobilenet_v3_large
27
- from .base import DBPostProcessor, _DBNet
28
-
29
- __all__ = ["DBNet", "db_resnet50", "db_mobilenet_v3_large"]
30
-
31
-
32
# Normalization statistics and input resolution common to every DBNet variant declared below.
_SHARED_CFG: dict[str, Any] = {
    "mean": (0.798, 0.785, 0.772),
    "std": (0.264, 0.2749, 0.287),
    "input_shape": (1024, 1024, 3),
}

# Per-architecture configuration: the shared entries above plus the checkpoint URL of each model.
# Every architecture gets its own dict object, so patching one entry never leaks into another.
default_cfgs: dict[str, dict[str, Any]] = {
    "db_resnet50": {
        **_SHARED_CFG,
        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0",
    },
    "db_mobilenet_v3_large": {
        **_SHARED_CFG,
        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0",
    },
}
46
-
47
-
48
class FeaturePyramidNetwork(layers.Layer, NestedObject):
    """Feature Pyramid Network as described in `"Feature Pyramid Networks for Object Detection"
    <https://arxiv.org/pdf/1612.03144.pdf>`_.

    Four feature maps are projected to a common channel count, merged top-down, refined by a
    per-level block, and concatenated into a single tensor.

    Args:
        channels: number of channel to output
    """

    def __init__(
        self,
        channels: int,
    ) -> None:
        super().__init__()
        self.channels = channels
        # Fixed 2x nearest-neighbour upsampling used when merging a level into the previous one
        self.upsample = layers.UpSampling2D(size=(2, 2), interpolation="nearest")
        # 1x1 convolutions mapping each of the 4 input feature maps to `channels` channels
        self.inner_blocks = [layers.Conv2D(channels, 1, strides=1, kernel_initializer="he_normal") for _ in range(4)]
        # Level `idx` is upsampled by 2**idx (1, 2, 4, 8) after its 3x3 convolution
        self.layer_blocks = [self.build_upsampling(channels, dilation_factor=2**idx) for idx in range(4)]

    @staticmethod
    def build_upsampling(
        channels: int,
        dilation_factor: int = 1,
    ) -> layers.Layer:
        """Module which performs a 3x3 convolution followed by up-sampling

        Args:
            channels: number of output channels
            dilation_factor (int): up-sampling factor applied to the convolution output before
                concatenation; no up-sampling layer is added when it is 1 or less

        Returns:
            a keras.layers.Layer object, wrapping these operations in a sequential module

        """
        # NOTE(review): `conv_sequence` is a project helper; presumably it returns the list of layers
        # for a conv + batch-norm + ReLU stack — confirm against doctr.models.utils.
        _layers = conv_sequence(channels, "relu", True, kernel_size=3)

        if dilation_factor > 1:
            _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest"))

        return Sequential(_layers)

    def extra_repr(self) -> str:
        return f"channels={self.channels}"

    def call(
        self,
        x: list[tf.Tensor],
        **kwargs: Any,
    ) -> tf.Tensor:
        """Merge the backbone feature maps into a single tensor.

        Args:
            x: feature maps, one per pyramid level. Assumes they are ordered from the highest to the
                lowest resolution and that each one is half the size of the previous one (the fixed
                2x upsampling below requires it) — TODO confirm against the feature extractor.
            **kwargs: forwarded to every sub-layer call

        Returns:
            the concatenation of the processed feature maps
        """
        # Channel mapping
        results = [block(fmap, **kwargs) for block, fmap in zip(self.inner_blocks, x)]
        # Upsample & sum (top-down pathway): start from the second-to-last level and walk back to the
        # first one so that each level receives the already-merged coarser level.
        # The previous `range(len(results) - 1, -1)` used the default step of +1 and was therefore
        # empty, which silently disabled this merge.
        # NOTE(review): checkpoints trained while the merge was disabled may need to be re-validated.
        for idx in range(len(results) - 2, -1, -1):
            results[idx] = results[idx] + self.upsample(results[idx + 1])
        # Conv & upsample
        results = [block(fmap, **kwargs) for block, fmap in zip(self.layer_blocks, results)]

        return layers.concatenate(results)
107
-
108
-
109
class DBNet(_DBNet, Model, NestedObject):
    """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
    <https://arxiv.org/pdf/1911.08947.pdf>`_.

    Args:
        feature_extractor: the backbone serving as feature extractor
        fpn_channels: number of channels each extracted feature maps is mapped to
        bin_thresh: threshold for binarization
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: if True, fit straight bounding boxes only
        exportable: onnx exportable returns only logits
        cfg: the configuration dict of the model
        class_names: list of class names, defaults to ``[CLASS_NAME]`` when omitted
    """

    _children_names: list[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]

    def __init__(
        self,
        feature_extractor: IntermediateLayerGetter,
        fpn_channels: int = 128,  # to be set to 256 to represent the author's initial idea
        bin_thresh: float = 0.3,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
        exportable: bool = False,
        cfg: dict[str, Any] | None = None,
        class_names: list[str] | None = None,
    ) -> None:
        super().__init__()
        # Keep a private copy: a list default would be shared by every instance, and storing the
        # caller's list by reference would let later external mutations change the model's classes.
        self.class_names = [CLASS_NAME] if class_names is None else list(class_names)
        num_classes: int = len(self.class_names)
        self.cfg = cfg

        self.feat_extractor = feature_extractor
        self.exportable = exportable
        self.assume_straight_pages = assume_straight_pages

        self.fpn = FeaturePyramidNetwork(channels=fpn_channels)
        # Initialize kernels: run the FPN once on symbolic inputs to learn the shape fed to the heads
        _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
        output_shape = tuple(self.fpn(_inputs).shape)

        # Both heads share the same layout: a 3x3 conv block followed by two stride-2 transposed
        # convolutions, ending with one output channel per class.
        self.probability_head = Sequential([
            *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
            layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
            layers.BatchNormalization(),
            layers.Activation("relu"),
            layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
        ])
        self.threshold_head = Sequential([
            *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
            layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
            layers.BatchNormalization(),
            layers.Activation("relu"),
            layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
        ])

        self.postprocessor = DBPostProcessor(
            assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
        )

    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
        """Load pretrained parameters onto the model

        Args:
            path_or_url: the path or URL to the model parameters (checkpoint)
            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
        """
        load_pretrained_params(self, path_or_url, **kwargs)

    def compute_loss(
        self,
        out_map: tf.Tensor,
        thresh_map: tf.Tensor,
        target: list[dict[str, np.ndarray]],
        gamma: float = 2.0,
        alpha: float = 0.5,
        eps: float = 1e-8,
    ) -> tf.Tensor:
        """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
        and a list of masks for each image. From there it computes the loss with the model output

        The result is ``l1_loss + 10 * focal_loss + dice_loss``.

        Args:
            out_map: output feature map of the model of shape (N, H, W, C)
            thresh_map: threshold map of shape (N, H, W, C)
            target: list of dictionary where each dict has a `boxes` and a `flags` entry
            gamma: modulating factor in the focal loss formula
            alpha: balancing factor in the focal loss formula
            eps: epsilon factor in dice loss

        Returns:
            A loss tensor

        Raises:
            ValueError: if `gamma` is negative
        """
        if gamma < 0:
            raise ValueError("Value of gamma should be greater than or equal to zero.")

        # Both maps are received as logits
        prob_map = tf.math.sigmoid(out_map)
        thresh_map = tf.math.sigmoid(thresh_map)

        seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True)
        seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
        seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
        seg_mask = tf.cast(seg_mask, tf.float32)
        thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
        thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)

        # Focal loss
        focal_scale = 10.0
        bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)

        # Convert logits to prob, compute gamma factor
        p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
        alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
        # Unreduced loss
        focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
        # Class reduced: average over the unmasked positions. `divide_no_nan` yields 0 instead of NaN
        # when the mask is entirely empty, and is identical to a plain division otherwise.
        focal_loss = tf.math.divide_no_nan(
            tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)),
            tf.reduce_sum(seg_mask, (0, 1, 2, 3)),
        )

        # Compute dice loss for each class or for approx binary_map
        if len(self.class_names) > 1:
            dice_map = tf.nn.softmax(out_map, axis=-1)
        else:
            # compute binary map instead
            dice_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
        # Class-reduced dice loss
        inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
        cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
        dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))

        # Compute l1 loss for thresh_map (only where the threshold mask is set)
        if tf.reduce_any(thresh_mask):
            thresh_mask = tf.cast(thresh_mask, tf.float32)
            l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / (
                tf.reduce_sum(thresh_mask) + eps
            )
        else:
            l1_loss = tf.constant(0.0)

        return l1_loss + focal_scale * focal_loss + dice_loss

    def call(
        self,
        x: tf.Tensor,
        target: list[dict[str, np.ndarray]] | None = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Run the detection model on a batch of images.

        Args:
            x: input tensor fed to the feature extractor
            target: optional ground truth, when given the loss is computed
            return_model_output: if True, include the probability map under ``"out_map"``
            return_preds: if True, include post-processed predictions even when `target` is given
            **kwargs: forwarded to the sub-model calls

        Returns:
            a dictionary holding, depending on the arguments: ``"logits"`` (exportable mode only, in
            which case nothing else is returned), ``"out_map"``, ``"preds"`` and ``"loss"``
        """
        feat_maps = self.feat_extractor(x, **kwargs)
        feat_concat = self.fpn(feat_maps, **kwargs)
        logits = self.probability_head(feat_concat, **kwargs)

        out: dict[str, tf.Tensor] = {}
        if self.exportable:
            out["logits"] = logits
            return out

        if return_model_output or target is None or return_preds:
            prob_map = _bf16_to_float32(tf.math.sigmoid(logits))

        if return_model_output:
            out["out_map"] = prob_map

        if target is None or return_preds:
            # Post-process boxes (keep only text predictions)
            out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]

        if target is not None:
            # The threshold head is only evaluated when a loss has to be computed
            thresh_map = self.threshold_head(feat_concat, **kwargs)
            loss = self.compute_loss(logits, thresh_map, target)
            out["loss"] = loss

        return out
282
-
283
-
284
def _db_resnet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: list[str],
    pretrained_backbone: bool = True,
    input_shape: tuple[int, int, int] | None = None,
    **kwargs: Any,
) -> DBNet:
    """Assemble a DBNet on top of a Keras-applications style backbone.

    Args:
        arch: key of the architecture in `default_cfgs`
        pretrained: if True, load the checkpoint referenced by the architecture config
        backbone_fn: callable building the backbone, called with `weights`, `include_top`,
            `pooling` and `input_shape`
        fpn_layers: names of the backbone layers whose outputs feed the FPN
        pretrained_backbone: if True (and `pretrained` is False), build the backbone with ImageNet weights
        input_shape: optional override of the configured input shape
        **kwargs: keyword arguments of the DBNet architecture

    Returns:
        the built (and optionally pretrained) model
    """
    # Backbone weights are only worth fetching when the full checkpoint is not loaded afterwards
    pretrained_backbone = pretrained_backbone and not pretrained

    # Work on a private copy of the config so the module-level defaults stay untouched
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["input_shape"] = input_shape or _cfg["input_shape"]

    requested_classes = kwargs.get("class_names", None)
    if requested_classes:
        kwargs["class_names"] = sorted(requested_classes)
    else:
        kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME])

    # Feature extractor
    backbone = backbone_fn(
        weights="imagenet" if pretrained_backbone else None,
        include_top=False,
        pooling=None,
        input_shape=_cfg["input_shape"],
    )
    feat_extractor = IntermediateLayerGetter(backbone, fpn_layers)

    # Build the model
    model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
    _build_model(model)

    if not pretrained:
        return model

    # Custom class names no longer match the checkpoint heads: skip those layers for fine tuning
    reference_classes = default_cfgs[arch].get("class_names", [CLASS_NAME])
    model.from_pretrained(_cfg["url"], skip_mismatch=kwargs["class_names"] != reference_classes)

    return model
327
-
328
-
329
def _db_mobilenet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: list[str],
    pretrained_backbone: bool = True,
    input_shape: tuple[int, int, int] | None = None,
    **kwargs: Any,
) -> DBNet:
    """Assemble a DBNet on top of a docTR MobileNet backbone.

    Args:
        arch: key of the architecture in `default_cfgs`
        pretrained: if True, load the checkpoint referenced by the architecture config
        backbone_fn: callable building the backbone, called with `input_shape`, `include_top`
            and `pretrained`
        fpn_layers: names of the backbone layers whose outputs feed the FPN
        pretrained_backbone: if True (and `pretrained` is False), build the backbone with its pretrained weights
        input_shape: optional override of the configured input shape
        **kwargs: keyword arguments of the DBNet architecture

    Returns:
        the built (and optionally pretrained) model
    """
    # Backbone weights are only worth fetching when the full checkpoint is not loaded afterwards
    pretrained_backbone = pretrained_backbone and not pretrained

    # Work on a private copy of the config so the module-level defaults stay untouched
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["input_shape"] = input_shape or _cfg["input_shape"]

    requested_classes = kwargs.get("class_names", None)
    if requested_classes:
        kwargs["class_names"] = sorted(requested_classes)
    else:
        kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])

    # Feature extractor
    backbone = backbone_fn(
        input_shape=_cfg["input_shape"],
        include_top=False,
        pretrained=pretrained_backbone,
    )
    feat_extractor = IntermediateLayerGetter(backbone, fpn_layers)

    # Build the model
    model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
    _build_model(model)

    if not pretrained:
        return model

    # Custom class names no longer match the checkpoint heads: skip those layers for fine tuning
    reference_classes = default_cfgs[arch].get("class_names", [CLASS_NAME])
    model.from_pretrained(_cfg["url"], skip_mismatch=kwargs["class_names"] != reference_classes)

    return model
370
-
371
-
372
def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
    """Build a DBNet text detector with a ResNet-50 backbone, following `"Real-time Scene Text
    Detection with Differentiable Binarization" <https://arxiv.org/pdf/1911.08947.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import db_resnet50
    >>> model = db_resnet50(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the DBNet architecture

    Returns:
        text detection architecture
    """
    # Backbone layers whose outputs are handed to the feature pyramid
    fpn_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"]
    return _db_resnet("db_resnet50", pretrained, ResNet50, fpn_layers, **kwargs)
396
-
397
-
398
def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
    """Build a DBNet text detector with a MobileNet V3 Large backbone, following `"Real-time Scene
    Text Detection with Differentiable Binarization" <https://arxiv.org/pdf/1911.08947.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import db_mobilenet_v3_large
    >>> model = db_mobilenet_v3_large(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the DBNet architecture

    Returns:
        text detection architecture
    """
    # Backbone layers whose outputs are handed to the feature pyramid
    fpn_layers = ["inverted_2", "inverted_5", "inverted_11", "final_block"]
    return _db_mobilenet("db_mobilenet_v3_large", pretrained, mobilenet_v3_large, fpn_layers, **kwargs)