python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
doctr/models/classification/textnet/tensorflow.py  ADDED

@@ -0,0 +1,267 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+from tensorflow.keras import Sequential, layers
+
+from doctr.datasets import VOCABS
+
+from ...modules.layers.tensorflow import FASTConvLayer
+from ...utils import conv_sequence, load_pretrained_params
+
+__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "textnet_tiny": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (32, 32, 3),
+        "classes": list(VOCABS["french"]),
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-9e605bd8.zip&src=0",
+    },
+    "textnet_small": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (32, 32, 3),
+        "classes": list(VOCABS["french"]),
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-4784b292.zip&src=0",
+    },
+    "textnet_base": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (32, 32, 3),
+        "classes": list(VOCABS["french"]),
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-2c3f3265.zip&src=0",
+    },
+}
+
+
+class TextNet(Sequential):
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+    Args:
+    ----
+        stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
+        include_top (bool, optional): Whether to include the classifier head. Defaults to True.
+        num_classes (int, optional): Number of output classes. Defaults to 1000.
+        cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        stages: List[Dict[str, List[int]]],
+        input_shape: Tuple[int, int, int] = (32, 32, 3),
+        num_classes: int = 1000,
+        include_top: bool = True,
+        cfg: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        _layers = [
+            *conv_sequence(
+                out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape
+            ),
+            *[
+                Sequential(
+                    [
+                        FASTConvLayer(**params)  # type: ignore[arg-type]
+                        for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]
+                    ],
+                    name=f"stage_{i}",
+                )
+                for i, stage in enumerate(stages)
+            ],
+        ]
+
+        if include_top:
+            _layers.append(
+                Sequential(
+                    [
+                        layers.AveragePooling2D(1),
+                        layers.Flatten(),
+                        layers.Dense(num_classes),
+                    ],
+                    name="classifier",
+                )
+            )
+
+        super().__init__(_layers)
+        self.cfg = cfg
+
+
+def _textnet(
+    arch: str,
+    pretrained: bool,
+    **kwargs: Any,
+) -> TextNet:
+    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+    kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["num_classes"] = kwargs["num_classes"]
+    _cfg["input_shape"] = kwargs["input_shape"]
+    _cfg["classes"] = kwargs["classes"]
+    kwargs.pop("classes")
+
+    # Build the model
+    model = TextNet(cfg=_cfg, **kwargs)
+    # Load pretrained parameters
+    if pretrained:
+        load_pretrained_params(model, default_cfgs[arch]["url"])
+
+    return model
+
+
+def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnet_tiny
+    >>> model = textnet_tiny(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the TextNet architecture
+
+    Returns:
+    -------
+        A textnet tiny model
+    """
+    return _textnet(
+        "textnet_tiny",
+        pretrained,
+        stages=[
+            {"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]},
+            {
+                "in_channels": [64, 128, 128, 128],
+                "out_channels": [128] * 4,
+                "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)],
+                "stride": [2, 1, 1, 1],
+            },
+            {
+                "in_channels": [128, 256, 256, 256],
+                "out_channels": [256] * 4,
+                "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)],
+                "stride": [2, 1, 1, 1],
+            },
+            {
+                "in_channels": [256, 512, 512, 512],
+                "out_channels": [512] * 4,
+                "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)],
+                "stride": [2, 1, 1, 1],
+            },
+        ],
+        **kwargs,
+    )
+
+
+def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnet_small
+    >>> model = textnet_small(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the TextNet architecture
+
+    Returns:
+    -------
+        A TextNet small model
+    """
+    return _textnet(
+        "textnet_small",
+        pretrained,
+        stages=[
+            {"in_channels": [64] * 2, "out_channels": [64] * 2, "kernel_size": [(3, 3)] * 2, "stride": [1, 2]},
+            {
+                "in_channels": [64, 128, 128, 128, 128, 128, 128, 128],
+                "out_channels": [128] * 8,
+                "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)],
+                "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+            },
+            {
+                "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
+                "out_channels": [256] * 8,
+                "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)],
+                "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+            },
+            {
+                "in_channels": [256, 512, 512, 512, 512],
+                "out_channels": [512] * 5,
+                "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)],
+                "stride": [2, 1, 1, 1, 1],
+            },
+        ],
+        **kwargs,
+    )
+
+
+def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnet_base
+    >>> model = textnet_base(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the TextNet architecture
+
+    Returns:
+    -------
+        A TextNet base model
+    """
+    return _textnet(
+        "textnet_base",
+        pretrained,
+        stages=[
+            {
+                "in_channels": [64] * 10,
+                "out_channels": [64] * 10,
+                "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)],
+                "stride": [1, 2, 1, 1, 1, 1, 1, 1, 1, 1],
+            },
+            {
+                "in_channels": [64, 128, 128, 128, 128, 128, 128, 128, 128, 128],
+                "out_channels": [128] * 10,
+                "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)],
+                "stride": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+            },
+            {
+                "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
+                "out_channels": [256] * 8,
+                "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)],
+                "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+            },
+            {
+                "in_channels": [256, 512, 512, 512, 512],
+                "out_channels": [512] * 5,
+                "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)],
+                "stride": [2, 1, 1, 1, 1],
+            },
+        ],
+        **kwargs,
+    )
doctr/models/classification/vgg/pytorch.py  CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -77,12 +77,14 @@ def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> tv_vgg.VGG:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        **kwargs: keyword arguments of the VGG architecture
 
     Returns:
+    -------
         VGG feature extractor
     """
-
     return _vgg(
         "vgg16_bn_r",
         pretrained,
doctr/models/classification/vgg/tensorflow.py  CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -32,6 +32,7 @@ class VGG(Sequential):
     <https://arxiv.org/pdf/1409.1556.pdf>`_.
 
     Args:
+    ----
         num_blocks: number of convolutional block in each stage
         planes: number of output channels in each stage
         rect_pools: whether pooling square kernels should be replace with rectangular ones
@@ -99,12 +100,14 @@ def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        **kwargs: keyword arguments of the VGG architecture
 
     Returns:
+    -------
         VGG feature extractor
     """
-
     return _vgg(
         "vgg16_bn_r", pretrained, [2, 2, 3, 3, 3], [64, 128, 256, 512, 512], [False, False, True, True, True], **kwargs
     )
doctr/models/classification/vit/pytorch.py  CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -40,6 +40,7 @@ class ClassifierHead(nn.Module):
     """Classifier head for Vision Transformer
 
     Args:
+    ----
         in_channels: number of input channels
         num_classes: number of output classes
     """
@@ -64,6 +65,7 @@ class VisionTransformer(nn.Sequential):
     <https://arxiv.org/pdf/2010.11929.pdf>`_.
 
     Args:
+    ----
         d_model: dimension of the transformer layers
         num_layers: number of transformer layers
         num_heads: number of attention heads
@@ -141,12 +143,14 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
+    -------
         A feature extractor model
     """
-
     return _vit(
         "vit_s",
         pretrained,
@@ -171,12 +175,14 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
    Args:
+    ----
         pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
+    -------
         A feature extractor model
     """
-
     return _vit(
         "vit_b",
         pretrained,
doctr/models/classification/vit/tensorflow.py  CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -41,6 +41,7 @@ class ClassifierHead(layers.Layer, NestedObject):
     """Classifier head for Vision Transformer
 
     Args:
+    ----
         num_classes: number of output classes
     """
 
@@ -60,6 +61,7 @@ class VisionTransformer(Sequential):
     <https://arxiv.org/pdf/2010.11929.pdf>`_.
 
    Args:
+    ----
         d_model: dimension of the transformer layers
         num_layers: number of transformer layers
         num_heads: number of attention heads
@@ -140,12 +142,14 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
+    -------
         A feature extractor model
     """
-
     return _vit(
         "vit_s",
         pretrained,
@@ -169,12 +173,14 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the VisionTransformer architecture
 
     Returns:
+    -------
         A feature extractor model
     """
-
     return _vit(
         "vit_b",
         pretrained,
doctr/models/classification/zoo.py  CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -24,6 +24,9 @@ ARCHS: List[str] = [
     "resnet34",
     "resnet50",
     "resnet34_wide",
+    "textnet_tiny",
+    "textnet_small",
+    "textnet_base",
     "vgg16_bn_r",
     "vit_s",
     "vit_b",
@@ -59,11 +62,13 @@ def crop_orientation_predictor(
     >>> out = model([input_crop])
 
     Args:
+    ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+        **kwargs: keyword arguments to be passed to the CropOrientationPredictor
 
     Returns:
+    -------
         CropOrientationPredictor
     """
-
     return _crop_orientation_predictor(arch, pretrained, **kwargs)
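Note: the doctest above leaves input_crop abstract; a concrete call could look like the following sketch (the dummy crop shape and values are assumptions, not from the package):

import numpy as np

from doctr.models import crop_orientation_predictor

predictor = crop_orientation_predictor(pretrained=True)  # weights are downloaded on first use
crop = (np.random.rand(64, 128, 3) * 255).astype(np.uint8)  # dummy (H, W, C) text crop
out = predictor([crop])  # one orientation prediction per input crop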
doctr/models/core.py  CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
doctr/models/detection/_utils/pytorch.py  CHANGED

@@ -13,9 +13,12 @@ def erode(x: Tensor, kernel_size: int) -> Tensor:
    """Performs erosion on a given tensor
 
    Args:
+    ----
        x: boolean tensor of shape (N, C, H, W)
        kernel_size: the size of the kernel to use for erosion
+
    Returns:
+    -------
        the eroded tensor
    """
    _pad = (kernel_size - 1) // 2
@@ -27,9 +30,12 @@ def dilate(x: Tensor, kernel_size: int) -> Tensor:
    """Performs dilation on a given tensor
 
    Args:
+    ----
        x: boolean tensor of shape (N, C, H, W)
        kernel_size: the size of the kernel to use for dilation
+
    Returns:
+    -------
        the dilated tensor
    """
    _pad = (kernel_size - 1) // 2
doctr/models/detection/_utils/tensorflow.py  CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,12 +12,14 @@ def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
    """Performs erosion on a given tensor
 
    Args:
+    ----
        x: boolean tensor of shape (N, H, W, C)
        kernel_size: the size of the kernel to use for erosion
+
    Returns:
+    -------
        the eroded tensor
    """
-
    return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
 
 
@@ -25,10 +27,12 @@ def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
    """Performs dilation on a given tensor
 
    Args:
+    ----
        x: boolean tensor of shape (N, H, W, C)
        kernel_size: the size of the kernel to use for dilation
+
    Returns:
+    -------
        the dilated tensor
    """
-
    return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
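Note: the erode/dilate helpers above reduce binary morphology to max-pooling: dilation is a plain max-pool, and erosion is dilation of the complement. A small standalone sketch on a toy map (illustration only, not from the package):

import numpy as np
import tensorflow as tf

# Toy binary map with a single active pixel, channels-last (N, H, W, C)
x = np.zeros((1, 5, 5, 1), dtype=np.float32)
x[0, 2, 2, 0] = 1.0
x = tf.convert_to_tensor(x)

dilated = tf.nn.max_pool2d(x, 3, strides=1, padding="SAME")  # foreground grows
eroded = 1 - tf.nn.max_pool2d(1 - x, 3, strides=1, padding="SAME")  # foreground shrinks

print(tf.reduce_sum(dilated).numpy())  # 9.0: the pixel grew into a 3x3 block
print(tf.reduce_sum(eroded).numpy())  # 0.0: a lone pixel is eroded away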
doctr/models/detection/core.py  CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -17,6 +17,7 @@ class DetectionPostProcessor(NestedObject):
    """Abstract class to postprocess the raw output of the model
 
    Args:
+    ----
        box_thresh (float): minimal objectness score to consider a box
        bin_thresh (float): threshold to apply to segmentation raw heatmap
        assume straight_pages (bool): if True, fit straight boxes only
@@ -36,9 +37,13 @@ class DetectionPostProcessor(NestedObject):
        """Compute the confidence score for a polygon : mean of the p values on the polygon
 
        Args:
+        ----
            pred (np.ndarray): p map returned by the model
+            points: coordinates of the polygon
+            assume_straight_pages: if True, fit straight boxes only
 
        Returns:
+        -------
            polygon objectness
        """
        h, w = pred.shape[:2]
@@ -52,7 +57,7 @@ class DetectionPostProcessor(NestedObject):
 
        else:
            mask: np.ndarray = np.zeros((h, w), np.int32)
-            cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
+            cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)  # type: ignore[call-overload]
            product = pred * mask
            return np.sum(product) / np.count_nonzero(product)
 
@@ -70,13 +75,14 @@ class DetectionPostProcessor(NestedObject):
        """Performs postprocessing for a list of model outputs
 
        Args:
+        ----
            proba_map: probability map of shape (N, H, W, C)
 
        Returns:
+        -------
            list of N class predictions (for each input sample), where each class predictions is a list of C tensors
            of shape (*, 5) or (*, 6)
        """
-
        if proba_map.ndim != 4:
            raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.")
 
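Note: the box_score change above only adds a typing suppression for the cv2.fillPoly overload; the computation itself is a masked mean over the probability map. A standalone sketch of that computation with hypothetical values:

import cv2
import numpy as np

pred = np.random.rand(32, 32).astype(np.float32)  # p map returned by the model
points = np.array([[4, 4], [28, 6], [26, 27], [5, 25]])  # polygon vertices as (x, y)

mask = np.zeros(pred.shape[:2], np.int32)
cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)  # rasterize the polygon
product = pred * mask
score = np.sum(product) / np.count_nonzero(product)  # mean probability inside the polygon
print(score)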