python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -61,6 +61,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
61
61
|
|
|
62
62
|
|
|
63
63
|
def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> List[nn.Module]:
|
|
64
|
+
"""Build a ResNet stage"""
|
|
64
65
|
_layers: List[nn.Module] = []
|
|
65
66
|
|
|
66
67
|
in_chan = in_channels
|
|
@@ -83,6 +84,7 @@ class ResNet(nn.Sequential):
|
|
|
83
84
|
Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
|
|
84
85
|
|
|
85
86
|
Args:
|
|
87
|
+
----
|
|
86
88
|
num_blocks: number of resnet block in each stage
|
|
87
89
|
output_channels: number of channels in each stage
|
|
88
90
|
stage_conv: whether to add a conv_sequence after each stage
|
|
@@ -134,13 +136,11 @@ class ResNet(nn.Sequential):
|
|
|
134
136
|
_layers.append(nn.Sequential(*_stage))
|
|
135
137
|
|
|
136
138
|
if include_top:
|
|
137
|
-
_layers.extend(
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
]
|
|
143
|
-
)
|
|
139
|
+
_layers.extend([
|
|
140
|
+
nn.AdaptiveAvgPool2d(1),
|
|
141
|
+
nn.Flatten(1),
|
|
142
|
+
nn.Linear(output_channels[-1], num_classes, bias=True),
|
|
143
|
+
])
|
|
144
144
|
|
|
145
145
|
super().__init__(*_layers)
|
|
146
146
|
self.cfg = cfg
|
|
@@ -224,12 +224,14 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> TVResNet:
|
|
|
224
224
|
>>> out = model(input_tensor)
|
|
225
225
|
|
|
226
226
|
Args:
|
|
227
|
+
----
|
|
227
228
|
pretrained: boolean, True if model is pretrained
|
|
229
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
228
230
|
|
|
229
231
|
Returns:
|
|
232
|
+
-------
|
|
230
233
|
A resnet18 model
|
|
231
234
|
"""
|
|
232
|
-
|
|
233
235
|
return _tv_resnet(
|
|
234
236
|
"resnet18",
|
|
235
237
|
pretrained,
|
|
@@ -251,12 +253,14 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
251
253
|
>>> out = model(input_tensor)
|
|
252
254
|
|
|
253
255
|
Args:
|
|
256
|
+
----
|
|
254
257
|
pretrained: boolean, True if model is pretrained
|
|
258
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
255
259
|
|
|
256
260
|
Returns:
|
|
261
|
+
-------
|
|
257
262
|
A resnet31 model
|
|
258
263
|
"""
|
|
259
|
-
|
|
260
264
|
return _resnet(
|
|
261
265
|
"resnet31",
|
|
262
266
|
pretrained,
|
|
@@ -283,12 +287,14 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> TVResNet:
|
|
|
283
287
|
>>> out = model(input_tensor)
|
|
284
288
|
|
|
285
289
|
Args:
|
|
290
|
+
----
|
|
286
291
|
pretrained: boolean, True if model is pretrained
|
|
292
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
287
293
|
|
|
288
294
|
Returns:
|
|
295
|
+
-------
|
|
289
296
|
A resnet34 model
|
|
290
297
|
"""
|
|
291
|
-
|
|
292
298
|
return _tv_resnet(
|
|
293
299
|
"resnet34",
|
|
294
300
|
pretrained,
|
|
@@ -309,12 +315,14 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
309
315
|
>>> out = model(input_tensor)
|
|
310
316
|
|
|
311
317
|
Args:
|
|
318
|
+
----
|
|
312
319
|
pretrained: boolean, True if model is pretrained
|
|
320
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
313
321
|
|
|
314
322
|
Returns:
|
|
323
|
+
-------
|
|
315
324
|
A resnet34_wide model
|
|
316
325
|
"""
|
|
317
|
-
|
|
318
326
|
return _resnet(
|
|
319
327
|
"resnet34_wide",
|
|
320
328
|
pretrained,
|
|
@@ -341,12 +349,14 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> TVResNet:
|
|
|
341
349
|
>>> out = model(input_tensor)
|
|
342
350
|
|
|
343
351
|
Args:
|
|
352
|
+
----
|
|
344
353
|
pretrained: boolean, True if model is pretrained
|
|
354
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
345
355
|
|
|
346
356
|
Returns:
|
|
357
|
+
-------
|
|
347
358
|
A resnet50 model
|
|
348
359
|
"""
|
|
349
|
-
|
|
350
360
|
return _tv_resnet(
|
|
351
361
|
"resnet50",
|
|
352
362
|
pretrained,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -58,10 +58,10 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
58
58
|
|
|
59
59
|
|
|
60
60
|
class ResnetBlock(layers.Layer):
|
|
61
|
-
|
|
62
61
|
"""Implements a resnet31 block with shortcut
|
|
63
62
|
|
|
64
63
|
Args:
|
|
64
|
+
----
|
|
65
65
|
conv_shortcut: Use of shortcut
|
|
66
66
|
output_channels: number of channels to use in Conv2D
|
|
67
67
|
kernel_size: size of square kernels
|
|
@@ -71,19 +71,17 @@ class ResnetBlock(layers.Layer):
|
|
|
71
71
|
def __init__(self, output_channels: int, conv_shortcut: bool, strides: int = 1, **kwargs) -> None:
|
|
72
72
|
super().__init__(**kwargs)
|
|
73
73
|
if conv_shortcut:
|
|
74
|
-
self.shortcut = Sequential(
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
]
|
|
86
|
-
)
|
|
74
|
+
self.shortcut = Sequential([
|
|
75
|
+
layers.Conv2D(
|
|
76
|
+
filters=output_channels,
|
|
77
|
+
strides=strides,
|
|
78
|
+
padding="same",
|
|
79
|
+
kernel_size=1,
|
|
80
|
+
use_bias=False,
|
|
81
|
+
kernel_initializer="he_normal",
|
|
82
|
+
),
|
|
83
|
+
layers.BatchNormalization(),
|
|
84
|
+
])
|
|
87
85
|
else:
|
|
88
86
|
self.shortcut = layers.Lambda(lambda x: x)
|
|
89
87
|
self.conv_block = Sequential(self.conv_resnetblock(output_channels, 3, strides))
|
|
@@ -123,6 +121,7 @@ class ResNet(Sequential):
|
|
|
123
121
|
"""Implements a ResNet architecture
|
|
124
122
|
|
|
125
123
|
Args:
|
|
124
|
+
----
|
|
126
125
|
num_blocks: number of resnet block in each stage
|
|
127
126
|
output_channels: number of channels in each stage
|
|
128
127
|
stage_downsample: whether the first residual block of a stage should downsample
|
|
@@ -177,12 +176,10 @@ class ResNet(Sequential):
|
|
|
177
176
|
inplanes = out_chan
|
|
178
177
|
|
|
179
178
|
if include_top:
|
|
180
|
-
_layers.extend(
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
]
|
|
185
|
-
)
|
|
179
|
+
_layers.extend([
|
|
180
|
+
layers.GlobalAveragePooling2D(),
|
|
181
|
+
layers.Dense(num_classes),
|
|
182
|
+
])
|
|
186
183
|
|
|
187
184
|
super().__init__(_layers)
|
|
188
185
|
self.cfg = cfg
|
|
@@ -231,12 +228,14 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
231
228
|
>>> out = model(input_tensor)
|
|
232
229
|
|
|
233
230
|
Args:
|
|
231
|
+
----
|
|
234
232
|
pretrained: boolean, True if model is pretrained
|
|
233
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
235
234
|
|
|
236
235
|
Returns:
|
|
236
|
+
-------
|
|
237
237
|
A classification model
|
|
238
238
|
"""
|
|
239
|
-
|
|
240
239
|
return _resnet(
|
|
241
240
|
"resnet18",
|
|
242
241
|
pretrained,
|
|
@@ -262,12 +261,14 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
262
261
|
>>> out = model(input_tensor)
|
|
263
262
|
|
|
264
263
|
Args:
|
|
264
|
+
----
|
|
265
265
|
pretrained: boolean, True if model is pretrained
|
|
266
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
266
267
|
|
|
267
268
|
Returns:
|
|
269
|
+
-------
|
|
268
270
|
A classification model
|
|
269
271
|
"""
|
|
270
|
-
|
|
271
272
|
return _resnet(
|
|
272
273
|
"resnet31",
|
|
273
274
|
pretrained,
|
|
@@ -293,12 +294,14 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
293
294
|
>>> out = model(input_tensor)
|
|
294
295
|
|
|
295
296
|
Args:
|
|
297
|
+
----
|
|
296
298
|
pretrained: boolean, True if model is pretrained
|
|
299
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
297
300
|
|
|
298
301
|
Returns:
|
|
302
|
+
-------
|
|
299
303
|
A classification model
|
|
300
304
|
"""
|
|
301
|
-
|
|
302
305
|
return _resnet(
|
|
303
306
|
"resnet34",
|
|
304
307
|
pretrained,
|
|
@@ -323,12 +326,14 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
323
326
|
>>> out = model(input_tensor)
|
|
324
327
|
|
|
325
328
|
Args:
|
|
329
|
+
----
|
|
326
330
|
pretrained: boolean, True if model is pretrained
|
|
331
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
327
332
|
|
|
328
333
|
Returns:
|
|
334
|
+
-------
|
|
329
335
|
A classification model
|
|
330
336
|
"""
|
|
331
|
-
|
|
332
337
|
kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"]))
|
|
333
338
|
kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs["resnet50"]["input_shape"])
|
|
334
339
|
kwargs["classes"] = kwargs.get("classes", default_cfgs["resnet50"]["classes"])
|
|
@@ -368,12 +373,14 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
368
373
|
>>> out = model(input_tensor)
|
|
369
374
|
|
|
370
375
|
Args:
|
|
376
|
+
----
|
|
371
377
|
pretrained: boolean, True if model is pretrained
|
|
378
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
372
379
|
|
|
373
380
|
Returns:
|
|
381
|
+
-------
|
|
374
382
|
A classification model
|
|
375
383
|
"""
|
|
376
|
-
|
|
377
384
|
return _resnet(
|
|
378
385
|
"resnet34_wide",
|
|
379
386
|
pretrained,
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
|
+
|
|
3
|
+
# This program is licensed under the Apache License 2.0.
|
|
4
|
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
from copy import deepcopy
|
|
8
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
9
|
+
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
from doctr.datasets import VOCABS
|
|
13
|
+
|
|
14
|
+
from ...modules.layers.pytorch import FASTConvLayer
|
|
15
|
+
from ...utils import conv_sequence_pt, load_pretrained_params
|
|
16
|
+
|
|
17
|
+
__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
|
|
18
|
+
|
|
19
|
+
default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
20
|
+
"textnet_tiny": {
|
|
21
|
+
"mean": (0.694, 0.695, 0.693),
|
|
22
|
+
"std": (0.299, 0.296, 0.301),
|
|
23
|
+
"input_shape": (3, 32, 32),
|
|
24
|
+
"classes": list(VOCABS["french"]),
|
|
25
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-c5970fe0.pt&src=0",
|
|
26
|
+
},
|
|
27
|
+
"textnet_small": {
|
|
28
|
+
"mean": (0.694, 0.695, 0.693),
|
|
29
|
+
"std": (0.299, 0.296, 0.301),
|
|
30
|
+
"input_shape": (3, 32, 32),
|
|
31
|
+
"classes": list(VOCABS["french"]),
|
|
32
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-6e8ab0ce.pt&src=0",
|
|
33
|
+
},
|
|
34
|
+
"textnet_base": {
|
|
35
|
+
"mean": (0.694, 0.695, 0.693),
|
|
36
|
+
"std": (0.299, 0.296, 0.301),
|
|
37
|
+
"input_shape": (3, 32, 32),
|
|
38
|
+
"classes": list(VOCABS["french"]),
|
|
39
|
+
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-8295dc85.pt&src=0",
|
|
40
|
+
},
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TextNet(nn.Sequential):
    """TextNet backbone from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.

    The model is a flat sequential stack: the modules returned by ``conv_sequence_pt`` (the stem),
    then one ``nn.Sequential`` of ``FASTConvLayer`` per stage and, when ``include_top`` is set,
    a pooling / flatten / linear head.

    Args:
    ----
        stages: one dictionary per stage. Every key is a ``FASTConvLayer`` keyword argument and every
            value is a list giving that argument for each layer of the stage; the number of layers
            is the length of the ``"in_channels"`` list.
        input_shape: not read by this constructor (the stem is always built with 3 input channels)
        num_classes: output size of the linear layer of the head
        include_top: whether to append the classification head
        cfg: configuration dictionary stored on the model as ``cfg``
    """

    def __init__(
        self,
        stages: List[Dict[str, List[int]]],
        input_shape: Tuple[int, int, int] = (3, 32, 32),
        num_classes: int = 1000,
        include_top: bool = True,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Stem: unpacked into top-level children so that module indices (and therefore
        # state-dict keys) match the original layout
        blocks: List[nn.Module] = list(
            conv_sequence_pt(
                in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
            )
        )

        # One nn.Sequential per stage, one FASTConvLayer per entry of the per-stage lists
        for stage in stages:
            depth = len(stage["in_channels"])
            stage_layers: List[nn.Module] = []
            for idx in range(depth):
                layer_kwargs = {name: values[idx] for name, values in stage.items()}
                stage_layers.append(FASTConvLayer(**layer_kwargs))  # type: ignore[arg-type]
            blocks.append(nn.Sequential(*stage_layers))

        if include_top:
            # The head consumes the channels produced by the last layer of the last stage
            head_in_features = stages[-1]["out_channels"][-1]
            head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(1),
                nn.Linear(head_in_features, num_classes),
            )
            blocks.append(head)

        super().__init__(*blocks)
        self.cfg = cfg

        # Re-initialize convolution and batch-norm parameters
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _textnet(
    arch: str,
    pretrained: bool,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> TextNet:
    """Build a TextNet model for a given architecture name and optionally load its pretrained parameters.

    Args:
    ----
        arch: key of ``default_cfgs`` identifying the architecture
        pretrained: whether to load the parameters located at the architecture's ``"url"``
        ignore_keys: parameter keys to skip while loading; only forwarded when ``num_classes``
            differs from the number of classes of the default configuration
        **kwargs: keyword arguments forwarded to ``TextNet``; ``num_classes`` and ``classes``
            fall back to the values of the default configuration

    Returns:
    -------
        the TextNet model, with its ``cfg`` attribute set to a copy of the default configuration
        updated with the effective ``num_classes`` and ``classes``
    """
    arch_cfg = default_cfgs[arch]
    default_num_classes = len(arch_cfg["classes"])

    # Resolve the class-related options; "classes" is not a TextNet argument, so it is removed
    kwargs.setdefault("num_classes", default_num_classes)
    class_names = kwargs.pop("classes", arch_cfg["classes"])

    model_cfg = deepcopy(arch_cfg)
    model_cfg["num_classes"] = kwargs["num_classes"]
    model_cfg["classes"] = class_names

    # Build the model
    model = TextNet(**kwargs)

    # Load pretrained parameters
    if pretrained:
        # When the number of classes differs from the pretrained model, the keys listed in
        # ignore_keys (the last layer weights) are skipped
        if kwargs["num_classes"] == default_num_classes:
            skipped_keys = None
        else:
            skipped_keys = ignore_keys
        load_pretrained_params(model, arch_cfg["url"], ignore_keys=skipped_keys)

    model.cfg = model_cfg

    return model
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
    """TextNet (tiny variant) from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnet_tiny
    >>> model = textnet_tiny(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: whether to load the pretrained parameters of the ``"textnet_tiny"`` configuration
        **kwargs: keyword arguments of the TextNet architecture

    Returns:
    -------
        A TextNet model built with the tiny stage layout
    """
    # Per-stage layer parameters: every list has one entry per layer of the stage
    stage_params = [
        {
            "in_channels": [64] * 3,
            "out_channels": [64] * 3,
            "kernel_size": [(3, 3)] * 3,
            "stride": [1, 2, 1],
        },
        {
            "in_channels": [64] + [128] * 3,
            "out_channels": [128] * 4,
            "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)],
            "stride": [2] + [1] * 3,
        },
        {
            "in_channels": [128] + [256] * 3,
            "out_channels": [256] * 4,
            "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)],
            "stride": [2] + [1] * 3,
        },
        {
            "in_channels": [256] + [512] * 3,
            "out_channels": [512] * 4,
            "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)],
            "stride": [2] + [1] * 3,
        },
    ]
    # Keys handed to _textnet, which only skips them when num_classes differs from the default
    head_keys = ["7.2.weight", "7.2.bias"]

    return _textnet("textnet_tiny", pretrained, stages=stage_params, ignore_keys=head_keys, **kwargs)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
    """TextNet (small variant) from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnet_small
    >>> model = textnet_small(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: whether to load the pretrained parameters of the ``"textnet_small"`` configuration
        **kwargs: keyword arguments of the TextNet architecture

    Returns:
    -------
        A TextNet model built with the small stage layout
    """
    # Per-stage layer parameters: every list has one entry per layer of the stage
    stage_params = [
        {
            "in_channels": [64] * 2,
            "out_channels": [64] * 2,
            "kernel_size": [(3, 3)] * 2,
            "stride": [1, 2],
        },
        {
            "in_channels": [64] + [128] * 7,
            "out_channels": [128] * 8,
            "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)],
            "stride": [2] + [1] * 7,
        },
        {
            "in_channels": [128] + [256] * 7,
            "out_channels": [256] * 8,
            "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)],
            "stride": [2] + [1] * 7,
        },
        {
            "in_channels": [256] + [512] * 4,
            "out_channels": [512] * 5,
            "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)],
            "stride": [2] + [1] * 4,
        },
    ]
    # Keys handed to _textnet, which only skips them when num_classes differs from the default
    head_keys = ["7.2.weight", "7.2.bias"]

    return _textnet("textnet_small", pretrained, stages=stage_params, ignore_keys=head_keys, **kwargs)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
    """TextNet (base variant) from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnet_base
    >>> model = textnet_base(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: whether to load the pretrained parameters of the ``"textnet_base"`` configuration
        **kwargs: keyword arguments of the TextNet architecture

    Returns:
    -------
        A TextNet model built with the base stage layout
    """
    # Per-stage layer parameters: every list has one entry per layer of the stage
    stage_params = [
        {
            "in_channels": [64] * 10,
            "out_channels": [64] * 10,
            "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)],
            "stride": [1, 2] + [1] * 8,
        },
        {
            "in_channels": [64] + [128] * 9,
            "out_channels": [128] * 10,
            "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)],
            "stride": [2] + [1] * 9,
        },
        {
            "in_channels": [128] + [256] * 7,
            "out_channels": [256] * 8,
            "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)],
            "stride": [2] + [1] * 7,
        },
        {
            "in_channels": [256] + [512] * 4,
            "out_channels": [512] * 5,
            "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)],
            "stride": [2] + [1] * 4,
        },
    ]
    # Keys handed to _textnet, which only skips them when num_classes differs from the default
    head_keys = ["7.2.weight", "7.2.bias"]

    return _textnet("textnet_base", pretrained, stages=stage_params, ignore_keys=head_keys, **kwargs)
|