python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -61,6 +61,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
61
61
 
62
62
 
63
63
  def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> List[nn.Module]:
64
+ """Build a ResNet stage"""
64
65
  _layers: List[nn.Module] = []
65
66
 
66
67
  in_chan = in_channels
@@ -83,6 +84,7 @@ class ResNet(nn.Sequential):
83
84
  Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
84
85
 
85
86
  Args:
87
+ ----
86
88
  num_blocks: number of resnet block in each stage
87
89
  output_channels: number of channels in each stage
88
90
  stage_conv: whether to add a conv_sequence after each stage
@@ -134,13 +136,11 @@ class ResNet(nn.Sequential):
134
136
  _layers.append(nn.Sequential(*_stage))
135
137
 
136
138
  if include_top:
137
- _layers.extend(
138
- [
139
- nn.AdaptiveAvgPool2d(1),
140
- nn.Flatten(1),
141
- nn.Linear(output_channels[-1], num_classes, bias=True),
142
- ]
143
- )
139
+ _layers.extend([
140
+ nn.AdaptiveAvgPool2d(1),
141
+ nn.Flatten(1),
142
+ nn.Linear(output_channels[-1], num_classes, bias=True),
143
+ ])
144
144
 
145
145
  super().__init__(*_layers)
146
146
  self.cfg = cfg
@@ -224,12 +224,14 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> TVResNet:
224
224
  >>> out = model(input_tensor)
225
225
 
226
226
  Args:
227
+ ----
227
228
  pretrained: boolean, True if model is pretrained
229
+ **kwargs: keyword arguments of the ResNet architecture
228
230
 
229
231
  Returns:
232
+ -------
230
233
  A resnet18 model
231
234
  """
232
-
233
235
  return _tv_resnet(
234
236
  "resnet18",
235
237
  pretrained,
@@ -251,12 +253,14 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
251
253
  >>> out = model(input_tensor)
252
254
 
253
255
  Args:
256
+ ----
254
257
  pretrained: boolean, True if model is pretrained
258
+ **kwargs: keyword arguments of the ResNet architecture
255
259
 
256
260
  Returns:
261
+ -------
257
262
  A resnet31 model
258
263
  """
259
-
260
264
  return _resnet(
261
265
  "resnet31",
262
266
  pretrained,
@@ -283,12 +287,14 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> TVResNet:
283
287
  >>> out = model(input_tensor)
284
288
 
285
289
  Args:
290
+ ----
286
291
  pretrained: boolean, True if model is pretrained
292
+ **kwargs: keyword arguments of the ResNet architecture
287
293
 
288
294
  Returns:
295
+ -------
289
296
  A resnet34 model
290
297
  """
291
-
292
298
  return _tv_resnet(
293
299
  "resnet34",
294
300
  pretrained,
@@ -309,12 +315,14 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
309
315
  >>> out = model(input_tensor)
310
316
 
311
317
  Args:
318
+ ----
312
319
  pretrained: boolean, True if model is pretrained
320
+ **kwargs: keyword arguments of the ResNet architecture
313
321
 
314
322
  Returns:
323
+ -------
315
324
  A resnet34_wide model
316
325
  """
317
-
318
326
  return _resnet(
319
327
  "resnet34_wide",
320
328
  pretrained,
@@ -341,12 +349,14 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> TVResNet:
341
349
  >>> out = model(input_tensor)
342
350
 
343
351
  Args:
352
+ ----
344
353
  pretrained: boolean, True if model is pretrained
354
+ **kwargs: keyword arguments of the ResNet architecture
345
355
 
346
356
  Returns:
357
+ -------
347
358
  A resnet50 model
348
359
  """
349
-
350
360
  return _tv_resnet(
351
361
  "resnet50",
352
362
  pretrained,
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -58,10 +58,10 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
58
58
 
59
59
 
60
60
  class ResnetBlock(layers.Layer):
61
-
62
61
  """Implements a resnet31 block with shortcut
63
62
 
64
63
  Args:
64
+ ----
65
65
  conv_shortcut: Use of shortcut
66
66
  output_channels: number of channels to use in Conv2D
67
67
  kernel_size: size of square kernels
@@ -71,19 +71,17 @@ class ResnetBlock(layers.Layer):
71
71
  def __init__(self, output_channels: int, conv_shortcut: bool, strides: int = 1, **kwargs) -> None:
72
72
  super().__init__(**kwargs)
73
73
  if conv_shortcut:
74
- self.shortcut = Sequential(
75
- [
76
- layers.Conv2D(
77
- filters=output_channels,
78
- strides=strides,
79
- padding="same",
80
- kernel_size=1,
81
- use_bias=False,
82
- kernel_initializer="he_normal",
83
- ),
84
- layers.BatchNormalization(),
85
- ]
86
- )
74
+ self.shortcut = Sequential([
75
+ layers.Conv2D(
76
+ filters=output_channels,
77
+ strides=strides,
78
+ padding="same",
79
+ kernel_size=1,
80
+ use_bias=False,
81
+ kernel_initializer="he_normal",
82
+ ),
83
+ layers.BatchNormalization(),
84
+ ])
87
85
  else:
88
86
  self.shortcut = layers.Lambda(lambda x: x)
89
87
  self.conv_block = Sequential(self.conv_resnetblock(output_channels, 3, strides))
@@ -123,6 +121,7 @@ class ResNet(Sequential):
123
121
  """Implements a ResNet architecture
124
122
 
125
123
  Args:
124
+ ----
126
125
  num_blocks: number of resnet block in each stage
127
126
  output_channels: number of channels in each stage
128
127
  stage_downsample: whether the first residual block of a stage should downsample
@@ -177,12 +176,10 @@ class ResNet(Sequential):
177
176
  inplanes = out_chan
178
177
 
179
178
  if include_top:
180
- _layers.extend(
181
- [
182
- layers.GlobalAveragePooling2D(),
183
- layers.Dense(num_classes),
184
- ]
185
- )
179
+ _layers.extend([
180
+ layers.GlobalAveragePooling2D(),
181
+ layers.Dense(num_classes),
182
+ ])
186
183
 
187
184
  super().__init__(_layers)
188
185
  self.cfg = cfg
@@ -231,12 +228,14 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
231
228
  >>> out = model(input_tensor)
232
229
 
233
230
  Args:
231
+ ----
234
232
  pretrained: boolean, True if model is pretrained
233
+ **kwargs: keyword arguments of the ResNet architecture
235
234
 
236
235
  Returns:
236
+ -------
237
237
  A classification model
238
238
  """
239
-
240
239
  return _resnet(
241
240
  "resnet18",
242
241
  pretrained,
@@ -262,12 +261,14 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
262
261
  >>> out = model(input_tensor)
263
262
 
264
263
  Args:
264
+ ----
265
265
  pretrained: boolean, True if model is pretrained
266
+ **kwargs: keyword arguments of the ResNet architecture
266
267
 
267
268
  Returns:
269
+ -------
268
270
  A classification model
269
271
  """
270
-
271
272
  return _resnet(
272
273
  "resnet31",
273
274
  pretrained,
@@ -293,12 +294,14 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
293
294
  >>> out = model(input_tensor)
294
295
 
295
296
  Args:
297
+ ----
296
298
  pretrained: boolean, True if model is pretrained
299
+ **kwargs: keyword arguments of the ResNet architecture
297
300
 
298
301
  Returns:
302
+ -------
299
303
  A classification model
300
304
  """
301
-
302
305
  return _resnet(
303
306
  "resnet34",
304
307
  pretrained,
@@ -323,12 +326,14 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
323
326
  >>> out = model(input_tensor)
324
327
 
325
328
  Args:
329
+ ----
326
330
  pretrained: boolean, True if model is pretrained
331
+ **kwargs: keyword arguments of the ResNet architecture
327
332
 
328
333
  Returns:
334
+ -------
329
335
  A classification model
330
336
  """
331
-
332
337
  kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"]))
333
338
  kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs["resnet50"]["input_shape"])
334
339
  kwargs["classes"] = kwargs.get("classes", default_cfgs["resnet50"]["classes"])
@@ -368,12 +373,14 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
368
373
  >>> out = model(input_tensor)
369
374
 
370
375
  Args:
376
+ ----
371
377
  pretrained: boolean, True if model is pretrained
378
+ **kwargs: keyword arguments of the ResNet architecture
372
379
 
373
380
  Returns:
381
+ -------
374
382
  A classification model
375
383
  """
376
-
377
384
  return _resnet(
378
385
  "resnet34_wide",
379
386
  pretrained,
@@ -0,0 +1,6 @@
1
+ from doctr.file_utils import is_tf_available, is_torch_available
2
+
3
+ if is_tf_available():
4
+ from .tensorflow import *
5
+ elif is_torch_available():
6
+ from .pytorch import * # type: ignore[assignment]
@@ -0,0 +1,275 @@
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+
7
+ from copy import deepcopy
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ from torch import nn
11
+
12
+ from doctr.datasets import VOCABS
13
+
14
+ from ...modules.layers.pytorch import FASTConvLayer
15
+ from ...utils import conv_sequence_pt, load_pretrained_params
16
+
17
+ __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
18
+
# Default configuration of each TextNet variant, keyed by architecture name.
# `_textnet` reads "classes" (its length is the default number of output classes) and
# "url" (location of the pretrained parameters); the whole entry is copied onto the model as `cfg`.
# "mean" / "std" are presumably the input normalization statistics -- not used in this module.
default_cfgs: Dict[str, Dict[str, Any]] = {
    "textnet_tiny": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-c5970fe0.pt&src=0",
    },
    "textnet_small": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-6e8ab0ce.pt&src=0",
    },
    "textnet_base": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-8295dc85.pt&src=0",
    },
}
42
+
43
+
class TextNet(nn.Sequential):
    """TextNet backbone introduced in `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Adapted from the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.

    The model is a flat ``nn.Sequential`` made of the modules returned by ``conv_sequence_pt``
    (the stem), one ``nn.Sequential`` of ``FASTConvLayer`` per entry of ``stages`` and, when
    ``include_top`` is set, a pooling / flatten / linear classification head.

    Args:
    ----
        stages: one dictionary per stage. Every key maps to a list with one value per layer; the
            i-th values of all keys are passed together as keyword arguments to the i-th
            ``FASTConvLayer``. The number of layers is the length of the ``"in_channels"`` list.
        input_shape: accepted for interface compatibility; this constructor does not read it.
        num_classes: number of output units of the classification head. Defaults to 1000.
        include_top: whether to append the classification head. Defaults to True.
        cfg: configuration dictionary stored on the instance as ``cfg``. Defaults to None.
    """

    def __init__(
        self,
        stages: List[Dict[str, List[int]]],
        input_shape: Tuple[int, int, int] = (3, 32, 32),
        num_classes: int = 1000,
        include_top: bool = True,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Stem: stride-2 3x3 convolution, requested with batch norm and ReLU
        stem = conv_sequence_pt(
            in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
        )
        modules: List[nn.Module] = [*stem]

        for stage_cfg in stages:
            depth = len(stage_cfg["in_channels"])
            # Transpose the per-key lists into one kwargs dict per layer before building any layer
            per_layer_kwargs = [{name: stage_cfg[name][idx] for name in stage_cfg} for idx in range(depth)]
            stage_layers = [
                FASTConvLayer(**layer_kwargs)  # type: ignore[arg-type]
                for layer_kwargs in per_layer_kwargs
            ]
            modules.append(nn.Sequential(*stage_layers))

        if include_top:
            # The head width follows the last layer of the last stage
            head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(1),
                nn.Linear(stages[-1]["out_channels"][-1], num_classes),
            )
            modules.append(head)

        super().__init__(*modules)
        self.cfg = cfg

        # Re-initialise every convolution and batch-norm layer, including nested ones
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
96
+
97
+
def _textnet(
    arch: str,
    pretrained: bool,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> TextNet:
    """Build a TextNet variant and optionally load its pretrained parameters.

    Args:
    ----
        arch: key of the variant in ``default_cfgs``
        pretrained: whether to load the parameters found at the variant's ``"url"``
        ignore_keys: parameter keys passed to ``load_pretrained_params`` as ``ignore_keys``; only
            forwarded when ``num_classes`` differs from the number of default classes
        **kwargs: keyword arguments of the TextNet architecture. ``num_classes`` defaults to the number
            of default classes; ``classes`` (default: the variant's class list) is recorded in the
            model configuration and is not forwarded to ``TextNet``.

    Returns:
    -------
        the TextNet model, with its ``cfg`` attribute set to the resolved configuration
    """
    arch_cfg = default_cfgs[arch]
    default_classes = arch_cfg["classes"]

    kwargs.setdefault("num_classes", len(default_classes))
    # `classes` only belongs in the stored configuration, so take it out of the constructor kwargs
    classes = kwargs.pop("classes", default_classes)

    _cfg = deepcopy(arch_cfg)
    _cfg["num_classes"] = kwargs["num_classes"]
    _cfg["classes"] = classes

    # Build the model
    model = TextNet(**kwargs)

    # Load pretrained parameters
    if pretrained:
        # A head sized differently from the pretrained one cannot reuse its weights => ignore them
        same_head = kwargs["num_classes"] == len(default_classes)
        load_pretrained_params(model, arch_cfg["url"], ignore_keys=None if same_head else ignore_keys)

    model.cfg = _cfg

    return model
124
+
125
+
def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
    """TextNet-Tiny from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Adapted from the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnet_tiny
    >>> model = textnet_tiny(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the TextNet architecture

    Returns:
    -------
        A TextNet tiny model
    """
    # One dictionary per stage; each list holds one value per layer of that stage
    stage_specs = [
        {
            "in_channels": [64] * 3,
            "out_channels": [64] * 3,
            "kernel_size": [(3, 3)] * 3,
            "stride": [1, 2, 1],
        },
        {
            "in_channels": [64, 128, 128, 128],
            "out_channels": [128] * 4,
            "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)],
            "stride": [2, 1, 1, 1],
        },
        {
            "in_channels": [128, 256, 256, 256],
            "out_channels": [256] * 4,
            "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)],
            "stride": [2, 1, 1, 1],
        },
        {
            "in_channels": [256, 512, 512, 512],
            "out_channels": [512] * 4,
            "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)],
            "stride": [2, 1, 1, 1],
        },
    ]
    # Keys forwarded as `ignore_keys` (presumably the classification head's linear layer)
    head_keys = ["7.2.weight", "7.2.bias"]

    return _textnet(
        "textnet_tiny",
        pretrained,
        stages=stage_specs,
        ignore_keys=head_keys,
        **kwargs,
    )
173
+
174
+
def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
    """TextNet-Small from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Adapted from the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnet_small
    >>> model = textnet_small(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the TextNet architecture

    Returns:
    -------
        A TextNet small model
    """
    # One dictionary per stage; each list holds one value per layer of that stage
    stage_specs = [
        {
            "in_channels": [64] * 2,
            "out_channels": [64] * 2,
            "kernel_size": [(3, 3)] * 2,
            "stride": [1, 2],
        },
        {
            "in_channels": [64, 128, 128, 128, 128, 128, 128, 128],
            "out_channels": [128] * 8,
            "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)],
            "stride": [2, 1, 1, 1, 1, 1, 1, 1],
        },
        {
            "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
            "out_channels": [256] * 8,
            "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)],
            "stride": [2, 1, 1, 1, 1, 1, 1, 1],
        },
        {
            "in_channels": [256, 512, 512, 512, 512],
            "out_channels": [512] * 5,
            "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)],
            "stride": [2, 1, 1, 1, 1],
        },
    ]
    # Keys forwarded as `ignore_keys` (presumably the classification head's linear layer)
    head_keys = ["7.2.weight", "7.2.bias"]

    return _textnet(
        "textnet_small",
        pretrained,
        stages=stage_specs,
        ignore_keys=head_keys,
        **kwargs,
    )
222
+
223
+
def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
    """TextNet-Base from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Adapted from the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnet_base
    >>> model = textnet_base(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the TextNet architecture

    Returns:
    -------
        A TextNet base model
    """
    # One dictionary per stage; each list holds one value per layer of that stage
    stage_specs = [
        {
            "in_channels": [64] * 10,
            "out_channels": [64] * 10,
            "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)],
            "stride": [1, 2, 1, 1, 1, 1, 1, 1, 1, 1],
        },
        {
            "in_channels": [64, 128, 128, 128, 128, 128, 128, 128, 128, 128],
            "out_channels": [128] * 10,
            "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)],
            "stride": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        },
        {
            "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
            "out_channels": [256] * 8,
            "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)],
            "stride": [2, 1, 1, 1, 1, 1, 1, 1],
        },
        {
            "in_channels": [256, 512, 512, 512, 512],
            "out_channels": [512] * 5,
            "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)],
            "stride": [2, 1, 1, 1, 1],
        },
    ]
    # Keys forwarded as `ignore_keys` (presumably the classification head's linear layer)
    head_keys = ["7.2.weight", "7.2.bias"]

    return _textnet(
        "textnet_base",
        pretrained,
        stages=stage_specs,
        ignore_keys=head_keys,
        **kwargs,
    )