python-doctr 0.12.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +0 -5
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/vocabs.py +0 -2
- doctr/file_utils.py +2 -101
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +3 -3
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +2 -2
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +1 -1
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +1 -1
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vip/__init__.py +1 -4
- doctr/models/classification/vip/layers/__init__.py +1 -4
- doctr/models/classification/vip/layers/pytorch.py +1 -1
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +2 -2
- doctr/models/classification/zoo.py +6 -11
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +4 -12
- doctr/models/detection/differentiable_binarization/pytorch.py +3 -3
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +4 -14
- doctr/models/detection/fast/pytorch.py +4 -4
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +3 -12
- doctr/models/detection/linknet/pytorch.py +2 -2
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +8 -21
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +2 -6
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +3 -3
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +2 -5
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +6 -6
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +5 -5
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +5 -5
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +7 -16
- doctr/models/recognition/predictor/pytorch.py +1 -2
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +3 -3
- doctr/models/recognition/viptr/__init__.py +1 -4
- doctr/models/recognition/viptr/pytorch.py +3 -3
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +3 -3
- doctr/models/recognition/zoo.py +13 -13
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +1 -1
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/geometry.py +6 -10
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +18 -75
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -442
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -418
- doctr/models/classification/textnet/tensorflow.py +0 -275
- doctr/models/classification/vgg/tensorflow.py +0 -125
- doctr/models/classification/vit/tensorflow.py +0 -201
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
- doctr/models/detection/fast/tensorflow.py +0 -427
- doctr/models/detection/linknet/tensorflow.py +0 -377
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -317
- doctr/models/recognition/master/tensorflow.py +0 -320
- doctr/models/recognition/parseq/tensorflow.py +0 -516
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -423
- doctr/models/recognition/vitstr/tensorflow.py +0 -285
- doctr/models/utils/tensorflow.py +0 -189
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.12.0.dist-info/RECORD +0 -180
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
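Every backend-specific tensorflow.py module in the list above is removed outright (+0 additions), while the pytorch.py counterparts are only lightly touched: the 1.0.0 wheel ships the PyTorch backend only. As a quick orientation, here is a hedged sketch of loading one of the retained classification models through the PyTorch backend; it assumes the public doctr.models entry points are unchanged in 1.0.0, and exact input shapes and defaults may differ.

import torch
from doctr.models import vgg16_bn_r

# Channels-first input, unlike the channels-last tensors in the removed TensorFlow examples below.
model = vgg16_bn_r(pretrained=False)
input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
out = model(input_tensor)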
doctr/models/classification/vgg/tensorflow.py (removed)
@@ -1,125 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-from copy import deepcopy
-from typing import Any
-
-from tensorflow.keras import layers
-from tensorflow.keras.models import Sequential
-
-from doctr.datasets import VOCABS
-
-from ...utils import _build_model, conv_sequence, load_pretrained_params
-
-__all__ = ["VGG", "vgg16_bn_r"]
-
-
-default_cfgs: dict[str, dict[str, Any]] = {
-    "vgg16_bn_r": {
-        "mean": (0.5, 0.5, 0.5),
-        "std": (1.0, 1.0, 1.0),
-        "input_shape": (32, 32, 3),
-        "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
-    },
-}
-
-
-class VGG(Sequential):
-    """Implements the VGG architecture from `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
-    <https://arxiv.org/pdf/1409.1556.pdf>`_.
-
-    Args:
-        num_blocks: number of convolutional block in each stage
-        planes: number of output channels in each stage
-        rect_pools: whether pooling square kernels should be replace with rectangular ones
-        include_top: whether the classifier head should be instantiated
-        num_classes: number of output classes
-        input_shape: shapes of the input tensor
-    """
-
-    def __init__(
-        self,
-        num_blocks: list[int],
-        planes: list[int],
-        rect_pools: list[bool],
-        include_top: bool = False,
-        num_classes: int = 1000,
-        input_shape: tuple[int, int, int] | None = None,
-        cfg: dict[str, Any] | None = None,
-    ) -> None:
-        _layers = []
-        # Specify input_shape only for the first layer
-        kwargs = {"input_shape": input_shape}
-        for nb_blocks, out_chan, rect_pool in zip(num_blocks, planes, rect_pools):
-            for _ in range(nb_blocks):
-                _layers.extend(conv_sequence(out_chan, "relu", True, kernel_size=3, **kwargs))  # type: ignore[arg-type]
-                kwargs = {}
-            _layers.append(layers.MaxPooling2D((2, 1 if rect_pool else 2)))
-
-        if include_top:
-            _layers.extend([layers.GlobalAveragePooling2D(), layers.Dense(num_classes)])
-        super().__init__(_layers)
-        self.cfg = cfg
-
-    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
-        """Load pretrained parameters onto the model
-
-        Args:
-            path_or_url: the path or URL to the model parameters (checkpoint)
-            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
-        """
-        load_pretrained_params(self, path_or_url, **kwargs)
-
-
-def _vgg(
-    arch: str, pretrained: bool, num_blocks: list[int], planes: list[int], rect_pools: list[bool], **kwargs: Any
-) -> VGG:
-    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
-    kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
-    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
-
-    _cfg = deepcopy(default_cfgs[arch])
-    _cfg["num_classes"] = kwargs["num_classes"]
-    _cfg["classes"] = kwargs["classes"]
-    _cfg["input_shape"] = kwargs["input_shape"]
-    kwargs.pop("classes")
-
-    # Build the model
-    model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
-    _build_model(model)
-
-    # Load pretrained parameters
-    if pretrained:
-        # The number of classes is not the same as the number of classes in the pretrained model =>
-        # skip the mismatching layers for fine tuning
-        model.from_pretrained(
-            default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
-        )
-
-    return model
-
-
-def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG:
-    """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
-    <https://arxiv.org/pdf/1409.1556.pdf>`_, modified by adding batch normalization, rectangular pooling and a simpler
-    classification head.
-
-    >>> import tensorflow as tf
-    >>> from doctr.models import vgg16_bn_r
-    >>> model = vgg16_bn_r(pretrained=False)
-    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        **kwargs: keyword arguments of the VGG architecture
-
-    Returns:
-        VGG feature extractor
-    """
-    return _vgg(
-        "vgg16_bn_r", pretrained, [2, 2, 3, 3, 3], [64, 128, 256, 512, 512], [False, False, True, True, True], **kwargs
-    )
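The removed file builds the feature extractor as stages of conv_sequence blocks, each stage ending in a pooling layer whose window becomes (2, 1) when the stage's rect_pools flag is set, so height is halved while width is preserved for text-shaped inputs. Below is a minimal PyTorch sketch of that same construction, illustrative only and not the retained doctr/models/classification/vgg/pytorch.py implementation.

import torch
from torch import nn


def make_vgg16_bn_r(
    num_blocks=(2, 2, 3, 3, 3),
    planes=(64, 128, 256, 512, 512),
    rect_pools=(False, False, True, True, True),
    in_channels=3,
) -> nn.Sequential:
    layers, chan = [], in_channels
    for nb, out_chan, rect in zip(num_blocks, planes, rect_pools):
        for _ in range(nb):
            # conv + batch norm + ReLU, the PyTorch analogue of conv_sequence(out_chan, "relu", True)
            layers += [nn.Conv2d(chan, out_chan, 3, padding=1), nn.BatchNorm2d(out_chan), nn.ReLU(inplace=True)]
            chan = out_chan
        # rectangular pooling: halve H only and keep W when rect is True
        layers.append(nn.MaxPool2d((2, 1) if rect else (2, 2)))
    return nn.Sequential(*layers)


features = make_vgg16_bn_r()
out = features(torch.rand(1, 3, 32, 32))  # shape (1, 512, 1, 8)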
doctr/models/classification/vit/tensorflow.py (removed)
@@ -1,201 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-from copy import deepcopy
-from typing import Any
-
-import tensorflow as tf
-from tensorflow.keras import Sequential, layers
-
-from doctr.datasets import VOCABS
-from doctr.models.modules.transformer import EncoderBlock
-from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
-from doctr.utils.repr import NestedObject
-
-from ...utils import _build_model, load_pretrained_params
-
-__all__ = ["vit_s", "vit_b"]
-
-
-default_cfgs: dict[str, dict[str, Any]] = {
-    "vit_s": {
-        "mean": (0.694, 0.695, 0.693),
-        "std": (0.299, 0.296, 0.301),
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
-    },
-    "vit_b": {
-        "mean": (0.694, 0.695, 0.693),
-        "std": (0.299, 0.296, 0.301),
-        "input_shape": (32, 32, 3),
-        "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
-    },
-}
-
-
-class ClassifierHead(layers.Layer, NestedObject):
-    """Classifier head for Vision Transformer
-
-    Args:
-        num_classes: number of output classes
-    """
-
-    def __init__(self, num_classes: int) -> None:
-        super().__init__()
-
-        self.head = layers.Dense(num_classes, kernel_initializer="he_normal", name="dense")
-
-    def call(self, x: tf.Tensor) -> tf.Tensor:
-        # (batch_size, num_classes) cls token
-        return self.head(x[:, 0])
-
-
-class VisionTransformer(Sequential):
-    """VisionTransformer architecture as described in
-    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
-    <https://arxiv.org/pdf/2010.11929.pdf>`_.
-
-    Args:
-        d_model: dimension of the transformer layers
-        num_layers: number of transformer layers
-        num_heads: number of attention heads
-        ffd_ratio: multiplier for the hidden dimension of the feedforward layer
-        patch_size: size of the patches
-        input_shape: size of the input image
-        dropout: dropout rate
-        num_classes: number of output classes
-        include_top: whether the classifier head should be instantiated
-    """
-
-    def __init__(
-        self,
-        d_model: int,
-        num_layers: int,
-        num_heads: int,
-        ffd_ratio: int,
-        patch_size: tuple[int, int] = (4, 4),
-        input_shape: tuple[int, int, int] = (32, 32, 3),
-        dropout: float = 0.0,
-        num_classes: int = 1000,
-        include_top: bool = True,
-        cfg: dict[str, Any] | None = None,
-    ) -> None:
-        _layers = [
-            PatchEmbedding(input_shape, d_model, patch_size),
-            EncoderBlock(
-                num_layers,
-                num_heads,
-                d_model,
-                d_model * ffd_ratio,
-                dropout,
-                activation_fct=layers.Activation("gelu"),
-            ),
-        ]
-        if include_top:
-            _layers.append(ClassifierHead(num_classes))
-
-        super().__init__(_layers)
-        self.cfg = cfg
-
-    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
-        """Load pretrained parameters onto the model
-
-        Args:
-            path_or_url: the path or URL to the model parameters (checkpoint)
-            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
-        """
-        load_pretrained_params(self, path_or_url, **kwargs)
-
-
-def _vit(
-    arch: str,
-    pretrained: bool,
-    **kwargs: Any,
-) -> VisionTransformer:
-    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
-    kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
-    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
-
-    _cfg = deepcopy(default_cfgs[arch])
-    _cfg["num_classes"] = kwargs["num_classes"]
-    _cfg["input_shape"] = kwargs["input_shape"]
-    _cfg["classes"] = kwargs["classes"]
-    kwargs.pop("classes")
-
-    # Build the model
-    model = VisionTransformer(cfg=_cfg, **kwargs)
-    _build_model(model)
-
-    # Load pretrained parameters
-    if pretrained:
-        # The number of classes is not the same as the number of classes in the pretrained model =>
-        # skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
-        )
-
-    return model
-
-
-def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
-    """VisionTransformer-S architecture
-    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
-    <https://arxiv.org/pdf/2010.11929.pdf>`_. Patches: (H, W) -> (H/8, W/8)
-
-    NOTE: unofficial config used in ViTSTR and ParSeq
-
-    >>> import tensorflow as tf
-    >>> from doctr.models import vit_s
-    >>> model = vit_s(pretrained=False)
-    >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-        **kwargs: keyword arguments of the VisionTransformer architecture
-
-    Returns:
-        A feature extractor model
-    """
-    return _vit(
-        "vit_s",
-        pretrained,
-        d_model=384,
-        num_layers=12,
-        num_heads=6,
-        ffd_ratio=4,
-        **kwargs,
-    )
-
-
-def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
-    """VisionTransformer-B architecture as described in
-    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
-    <https://arxiv.org/pdf/2010.11929.pdf>`_. Patches: (H, W) -> (H/8, W/8)
-
-    >>> import tensorflow as tf
-    >>> from doctr.models import vit_b
-    >>> model = vit_b(pretrained=False)
-    >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-        **kwargs: keyword arguments of the VisionTransformer architecture
-
-    Returns:
-        A feature extractor model
-    """
-    return _vit(
-        "vit_b",
-        pretrained,
-        d_model=768,
-        num_layers=12,
-        num_heads=12,
-        ffd_ratio=4,
-        **kwargs,
-    )
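In the removed model, classification reads only the CLS token: ClassifierHead applies a dense layer to x[:, 0] of the encoder output. A hedged PyTorch equivalent of that head is sketched below; it is hypothetical and not the retained doctr/models/classification/vit/pytorch.py code.

import torch
from torch import nn


class ClassifierHead(nn.Module):
    """Dense layer over the CLS token of a transformer encoder output."""

    def __init__(self, d_model: int, num_classes: int) -> None:
        super().__init__()
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, num_patches + 1, d_model) -> (batch_size, num_classes)
        return self.head(x[:, 0])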
doctr/models/detection/_utils/tensorflow.py (removed)
@@ -1,34 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-import tensorflow as tf
-
-__all__ = ["erode", "dilate"]
-
-
-def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
-    """Performs erosion on a given tensor
-
-    Args:
-        x: boolean tensor of shape (N, H, W, C)
-        kernel_size: the size of the kernel to use for erosion
-
-    Returns:
-        the eroded tensor
-    """
-    return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
-
-
-def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
-    """Performs dilation on a given tensor
-
-    Args:
-        x: boolean tensor of shape (N, H, W, C)
-        kernel_size: the size of the kernel to use for dilation
-
-    Returns:
-        the dilated tensor
-    """
-    return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
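These removed helpers implement binary morphology with max pooling: dilation is a plain sliding-window maximum, and erosion is the complement of a dilation of the complement. A hedged PyTorch sketch of the same trick on channels-first float masks in [0, 1] follows; it is illustrative rather than the retained doctr/models/detection/_utils/pytorch.py code, and the padding preserves spatial size only for odd kernel sizes (the TensorFlow version relied on padding="SAME").

import torch
from torch.nn.functional import max_pool2d


def erode(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    # erosion = 1 - dilation of the complement
    return 1 - max_pool2d(1 - x, kernel_size, stride=1, padding=(kernel_size - 1) // 2)


def dilate(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    # dilation = sliding-window maximum
    return max_pool2d(x, kernel_size, stride=1, padding=(kernel_size - 1) // 2)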