python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
doctr/models/classification/textnet/tensorflow.py ADDED
@@ -0,0 +1,267 @@
+ # Copyright (C) 2021-2024, Mindee.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+ from copy import deepcopy
+ from typing import Any, Dict, List, Optional, Tuple
+
+ from tensorflow.keras import Sequential, layers
+
+ from doctr.datasets import VOCABS
+
+ from ...modules.layers.tensorflow import FASTConvLayer
+ from ...utils import conv_sequence, load_pretrained_params
+
+ __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
+
+ default_cfgs: Dict[str, Dict[str, Any]] = {
+     "textnet_tiny": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (32, 32, 3),
+         "classes": list(VOCABS["french"]),
+         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-9e605bd8.zip&src=0",
+     },
+     "textnet_small": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (32, 32, 3),
+         "classes": list(VOCABS["french"]),
+         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-4784b292.zip&src=0",
+     },
+     "textnet_base": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (32, 32, 3),
+         "classes": list(VOCABS["french"]),
+         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-2c3f3265.zip&src=0",
+     },
+ }
+
+
+ class TextNet(Sequential):
+     """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+     Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+     Implementation based on the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.
+
+     Args:
+     ----
+         stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
+         include_top (bool, optional): Whether to include the classifier head. Defaults to True.
+         num_classes (int, optional): Number of output classes. Defaults to 1000.
+         cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         stages: List[Dict[str, List[int]]],
+         input_shape: Tuple[int, int, int] = (32, 32, 3),
+         num_classes: int = 1000,
+         include_top: bool = True,
+         cfg: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         _layers = [
+             *conv_sequence(
+                 out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape
+             ),
+             *[
+                 Sequential(
+                     [
+                         FASTConvLayer(**params)  # type: ignore[arg-type]
+                         for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]
+                     ],
+                     name=f"stage_{i}",
+                 )
+                 for i, stage in enumerate(stages)
+             ],
+         ]
+
+         if include_top:
+             _layers.append(
+                 Sequential(
+                     [
+                         layers.AveragePooling2D(1),
+                         layers.Flatten(),
+                         layers.Dense(num_classes),
+                     ],
+                     name="classifier",
+                 )
+             )
+
+         super().__init__(_layers)
+         self.cfg = cfg
+
+
+ def _textnet(
+     arch: str,
+     pretrained: bool,
+     **kwargs: Any,
+ ) -> TextNet:
+     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+     kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+     kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+     _cfg = deepcopy(default_cfgs[arch])
+     _cfg["num_classes"] = kwargs["num_classes"]
+     _cfg["input_shape"] = kwargs["input_shape"]
+     _cfg["classes"] = kwargs["classes"]
+     kwargs.pop("classes")
+
+     # Build the model
+     model = TextNet(cfg=_cfg, **kwargs)
+     # Load pretrained parameters
+     if pretrained:
+         load_pretrained_params(model, default_cfgs[arch]["url"])
+
+     return model
+
+
+ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
+     """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+     Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+     Implementation based on the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.
+
+     >>> import tensorflow as tf
+     >>> from doctr.models import textnet_tiny
+     >>> model = textnet_tiny(pretrained=False)
+     >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         pretrained: boolean, True if model is pretrained
+         **kwargs: keyword arguments of the TextNet architecture
+
+     Returns:
+     -------
+         A TextNet tiny model
+     """
+     return _textnet(
+         "textnet_tiny",
+         pretrained,
+         stages=[
+             {"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]},
+             {
+                 "in_channels": [64, 128, 128, 128],
+                 "out_channels": [128] * 4,
+                 "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)],
+                 "stride": [2, 1, 1, 1],
+             },
+             {
+                 "in_channels": [128, 256, 256, 256],
+                 "out_channels": [256] * 4,
+                 "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)],
+                 "stride": [2, 1, 1, 1],
+             },
+             {
+                 "in_channels": [256, 512, 512, 512],
+                 "out_channels": [512] * 4,
+                 "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)],
+                 "stride": [2, 1, 1, 1],
+             },
+         ],
+         **kwargs,
+     )
+
+
+ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
+     """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+     Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+     Implementation based on the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.
+
+     >>> import tensorflow as tf
+     >>> from doctr.models import textnet_small
+     >>> model = textnet_small(pretrained=False)
+     >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         pretrained: boolean, True if model is pretrained
+         **kwargs: keyword arguments of the TextNet architecture
+
+     Returns:
+     -------
+         A TextNet small model
+     """
+     return _textnet(
+         "textnet_small",
+         pretrained,
+         stages=[
+             {"in_channels": [64] * 2, "out_channels": [64] * 2, "kernel_size": [(3, 3)] * 2, "stride": [1, 2]},
+             {
+                 "in_channels": [64, 128, 128, 128, 128, 128, 128, 128],
+                 "out_channels": [128] * 8,
+                 "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)],
+                 "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+             },
+             {
+                 "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
+                 "out_channels": [256] * 8,
+                 "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)],
+                 "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+             },
+             {
+                 "in_channels": [256, 512, 512, 512, 512],
+                 "out_channels": [512] * 5,
+                 "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)],
+                 "stride": [2, 1, 1, 1, 1],
+             },
+         ],
+         **kwargs,
+     )
+
+
+ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
+     """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+     Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+     Implementation based on the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.
+
+     >>> import tensorflow as tf
+     >>> from doctr.models import textnet_base
+     >>> model = textnet_base(pretrained=False)
+     >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+     >>> out = model(input_tensor)
+
+     Args:
+     ----
+         pretrained: boolean, True if model is pretrained
+         **kwargs: keyword arguments of the TextNet architecture
+
+     Returns:
+     -------
+         A TextNet base model
+     """
+     return _textnet(
+         "textnet_base",
+         pretrained,
+         stages=[
+             {
+                 "in_channels": [64] * 10,
+                 "out_channels": [64] * 10,
+                 "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)],
+                 "stride": [1, 2, 1, 1, 1, 1, 1, 1, 1, 1],
+             },
+             {
+                 "in_channels": [64, 128, 128, 128, 128, 128, 128, 128, 128, 128],
+                 "out_channels": [128] * 10,
+                 "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)],
+                 "stride": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+             },
+             {
+                 "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
+                 "out_channels": [256] * 8,
+                 "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)],
+                 "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+             },
+             {
+                 "in_channels": [256, 512, 512, 512, 512],
+                 "out_channels": [512] * 5,
+                 "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)],
+                 "stride": [2, 1, 1, 1, 1],
+             },
+         ],
+         **kwargs,
+     )
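
The nested comprehension in TextNet.__init__ above is the densest part of the file: each stage dict maps parameter names to per-layer lists, and it is transposed into one keyword-argument dict per FASTConvLayer. A standalone sketch of that expansion (plain Python, no doctr install required):

# Transpose a stage dict of per-layer lists into one kwargs dict per conv
# layer, exactly as the comprehension in TextNet.__init__ does.
stage = {"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]}

per_layer = [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]

print(per_layer[1])
# {'in_channels': 64, 'out_channels': 64, 'kernel_size': (3, 3), 'stride': 2}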
doctr/models/classification/vgg/pytorch.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -77,12 +77,14 @@ def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> tv_vgg.VGG:
      >>> out = model(input_tensor)
 
      Args:
+     ----
          pretrained (bool): If True, returns a model pre-trained on ImageNet
+         **kwargs: keyword arguments of the VGG architecture
 
      Returns:
+     -------
          VGG feature extractor
      """
-
      return _vgg(
          "vgg16_bn_r",
          pretrained,
doctr/models/classification/vgg/tensorflow.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -32,6 +32,7 @@ class VGG(Sequential):
      <https://arxiv.org/pdf/1409.1556.pdf>`_.
 
      Args:
+     ----
          num_blocks: number of convolutional blocks in each stage
          planes: number of output channels in each stage
          rect_pools: whether square pooling kernels should be replaced with rectangular ones
@@ -99,12 +100,14 @@ def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG:
      >>> out = model(input_tensor)
 
      Args:
+     ----
          pretrained (bool): If True, returns a model pre-trained on ImageNet
+         **kwargs: keyword arguments of the VGG architecture
 
      Returns:
+     -------
          VGG feature extractor
      """
-
      return _vgg(
          "vgg16_bn_r", pretrained, [2, 2, 3, 3, 3], [64, 128, 256, 512, 512], [False, False, True, True, True], **kwargs
      )
doctr/models/classification/vit/pytorch.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -40,6 +40,7 @@ class ClassifierHead(nn.Module):
      """Classifier head for Vision Transformer
 
      Args:
+     ----
          in_channels: number of input channels
          num_classes: number of output classes
      """
@@ -64,6 +65,7 @@ class VisionTransformer(nn.Sequential):
      <https://arxiv.org/pdf/2010.11929.pdf>`_.
 
      Args:
+     ----
          d_model: dimension of the transformer layers
          num_layers: number of transformer layers
          num_heads: number of attention heads
@@ -141,12 +143,14 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
      >>> out = model(input_tensor)
 
      Args:
+     ----
          pretrained: boolean, True if model is pretrained
+         **kwargs: keyword arguments of the VisionTransformer architecture
 
      Returns:
+     -------
          A feature extractor model
      """
-
      return _vit(
          "vit_s",
          pretrained,
@@ -171,12 +175,14 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
      >>> out = model(input_tensor)
 
      Args:
+     ----
          pretrained: boolean, True if model is pretrained
+         **kwargs: keyword arguments of the VisionTransformer architecture
 
      Returns:
+     -------
          A feature extractor model
      """
-
      return _vit(
          "vit_b",
          pretrained,
doctr/models/classification/vit/tensorflow.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -41,6 +41,7 @@ class ClassifierHead(layers.Layer, NestedObject):
      """Classifier head for Vision Transformer
 
      Args:
+     ----
          num_classes: number of output classes
      """
 
@@ -60,6 +61,7 @@ class VisionTransformer(Sequential):
      <https://arxiv.org/pdf/2010.11929.pdf>`_.
 
      Args:
+     ----
          d_model: dimension of the transformer layers
          num_layers: number of transformer layers
          num_heads: number of attention heads
@@ -140,12 +142,14 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
      >>> out = model(input_tensor)
 
      Args:
+     ----
          pretrained: boolean, True if model is pretrained
+         **kwargs: keyword arguments of the VisionTransformer architecture
 
      Returns:
+     -------
          A feature extractor model
      """
-
      return _vit(
          "vit_s",
          pretrained,
@@ -169,12 +173,14 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
      >>> out = model(input_tensor)
 
      Args:
+     ----
          pretrained: boolean, True if model is pretrained
+         **kwargs: keyword arguments of the VisionTransformer architecture
 
      Returns:
+     -------
          A feature extractor model
      """
-
      return _vit(
          "vit_b",
          pretrained,
doctr/models/classification/zoo.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -24,6 +24,9 @@ ARCHS: List[str] = [
      "resnet34",
      "resnet50",
      "resnet34_wide",
+     "textnet_tiny",
+     "textnet_small",
+     "textnet_base",
      "vgg16_bn_r",
      "vit_s",
      "vit_b",
@@ -59,11 +62,13 @@ def crop_orientation_predictor(
      >>> out = model([input_crop])
 
      Args:
+     ----
          arch: name of the architecture to use (e.g. 'mobilenet_v3_small')
          pretrained: If True, returns a model pre-trained on our recognition crops dataset
+         **kwargs: keyword arguments to be passed to the CropOrientationPredictor
 
      Returns:
+     -------
          CropOrientationPredictor
      """
-
      return _crop_orientation_predictor(arch, pretrained, **kwargs)
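
A hedged usage sketch of the factory above, assuming a doctr 0.8.x install; the default arch and the list-in/list-out calling convention are taken from the docstring:

import numpy as np

from doctr.models import crop_orientation_predictor

# default arch, random weights (pass pretrained=True for the released checkpoint)
classifier = crop_orientation_predictor(pretrained=False)
crop = np.zeros((64, 256, 3), dtype=np.uint8)  # stand-in for a text crop (H, W, C)
out = classifier([crop])  # list of crops in, list of orientation predictions out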
doctr/models/core.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
doctr/models/detection/__init__.py CHANGED
@@ -1,3 +1,4 @@
  from .differentiable_binarization import *
  from .linknet import *
+ from .fast import *
  from .zoo import *
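
With the fast module re-exported, the FAST detectors introduced in this release should be reachable through the usual detection zoo. A hedged sketch; the arch string "fast_tiny" mirrors the textnet_tiny/small/base naming and is an assumption here:

from doctr.models import detection_predictor

# assumption: "fast_tiny" is one of the new FAST detection archs in 0.8.1
det = detection_predictor(arch="fast_tiny", pretrained=False)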
doctr/models/detection/_utils/pytorch.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,9 +13,12 @@ def erode(x: Tensor, kernel_size: int) -> Tensor:
      """Performs erosion on a given tensor
 
      Args:
+     ----
          x: boolean tensor of shape (N, C, H, W)
          kernel_size: the size of the kernel to use for erosion
+
      Returns:
+     -------
          the eroded tensor
      """
      _pad = (kernel_size - 1) // 2
@@ -27,9 +30,12 @@ def dilate(x: Tensor, kernel_size: int) -> Tensor:
      """Performs dilation on a given tensor
 
      Args:
+     ----
          x: boolean tensor of shape (N, C, H, W)
          kernel_size: the size of the kernel to use for dilation
+
      Returns:
+     -------
          the dilated tensor
      """
      _pad = (kernel_size - 1) // 2
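
A minimal usage sketch of the two helpers above (PyTorch, channels-first; importing from the private module shown in this hunk is an assumption, the package normally dispatches through doctr.models.detection._utils):

import torch

from doctr.models.detection._utils.pytorch import dilate, erode

# pseudo-binary segmentation map, shape (N, C, H, W) as the docstrings specify
x = (torch.rand(1, 1, 8, 8) > 0.5).float()
opened = dilate(erode(x, 3), 3)  # morphological opening removes isolated pixels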
doctr/models/detection/_utils/tensorflow.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,12 +12,14 @@ def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
      """Performs erosion on a given tensor
 
      Args:
+     ----
          x: boolean tensor of shape (N, H, W, C)
          kernel_size: the size of the kernel to use for erosion
+
      Returns:
+     -------
          the eroded tensor
      """
-
      return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
 
 
@@ -25,10 +27,12 @@ def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
      """Performs dilation on a given tensor
 
      Args:
+     ----
          x: boolean tensor of shape (N, H, W, C)
          kernel_size: the size of the kernel to use for dilation
+
      Returns:
+     -------
          the dilated tensor
      """
-
      return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
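
The TensorFlow twins expect channels-last input; a matching sketch under the same import caveat:

import tensorflow as tf

from doctr.models.detection._utils.tensorflow import dilate, erode

x = tf.cast(tf.random.uniform((1, 8, 8, 1)) > 0.5, tf.float32)  # (N, H, W, C)
opened = dilate(erode(x, 3), 3)  # dilation of an erosion = morphological opening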
doctr/models/detection/core.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.
 
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -17,6 +17,7 @@ class DetectionPostProcessor(NestedObject):
      """Abstract class to postprocess the raw output of the model
 
      Args:
+     ----
          box_thresh (float): minimal objectness score to consider a box
          bin_thresh (float): threshold to apply to segmentation raw heatmap
          assume_straight_pages (bool): if True, fit straight boxes only
@@ -36,9 +37,13 @@ class DetectionPostProcessor(NestedObject):
          """Compute the confidence score for a polygon: mean of the p values on the polygon
 
          Args:
+         ----
              pred (np.ndarray): p map returned by the model
+             points: coordinates of the polygon
+             assume_straight_pages: if True, fit straight boxes only
 
          Returns:
+         -------
              polygon objectness
          """
          h, w = pred.shape[:2]
@@ -52,7 +57,7 @@ class DetectionPostProcessor(NestedObject):
 
          else:
              mask: np.ndarray = np.zeros((h, w), np.int32)
-             cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
+             cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)  # type: ignore[call-overload]
              product = pred * mask
              return np.sum(product) / np.count_nonzero(product)
 
@@ -70,13 +75,14 @@ class DetectionPostProcessor(NestedObject):
          """Performs postprocessing for a list of model outputs
 
          Args:
+         ----
              proba_map: probability map of shape (N, H, W, C)
 
          Returns:
+         -------
              list of N class predictions (for each input sample), where each class prediction is a list of C tensors
              of shape (*, 5) or (*, 6)
          """
-
          if proba_map.ndim != 4:
              raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.")
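
For the straight-page branch, box_score reduces to the mean of the probability map over the box's bounding rectangle; a minimal NumPy sketch of that idea (not the library's exact code):

import numpy as np

def box_score_straight(pred: np.ndarray, points: np.ndarray) -> float:
    """Mean of the (H, W) probability map over the box's bounding rectangle."""
    h, w = pred.shape[:2]
    xs = np.clip(points[:, 0].astype(int), 0, w - 1)  # points: (4, 2) absolute (x, y) corners
    ys = np.clip(points[:, 1].astype(int), 0, h - 1)
    return float(pred[ys.min() : ys.max() + 1, xs.min() : xs.max() + 1].mean())

score = box_score_straight(np.random.rand(32, 32), np.array([[4, 4], [20, 4], [20, 10], [4, 10]]))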