python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +1 -5
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1100 -54
- doctr/file_utils.py +2 -92
- doctr/io/elements.py +37 -3
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +4 -4
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -4
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +11 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vip/__init__.py +1 -0
- doctr/models/classification/vip/layers/__init__.py +1 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +12 -3
- doctr/models/classification/zoo.py +7 -8
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +7 -16
- doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +6 -17
- doctr/models/detection/fast/pytorch.py +17 -8
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +5 -15
- doctr/models/detection/linknet/pytorch.py +12 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +9 -22
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +52 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +16 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +15 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +26 -8
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +100 -47
- doctr/models/recognition/predictor/pytorch.py +4 -5
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +13 -4
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +1 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +13 -4
- doctr/models/recognition/zoo.py +13 -8
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +29 -19
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -433
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -397
- doctr/models/classification/textnet/tensorflow.py +0 -266
- doctr/models/classification/vgg/tensorflow.py +0 -116
- doctr/models/classification/vit/tensorflow.py +0 -192
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
- doctr/models/detection/fast/tensorflow.py +0 -419
- doctr/models/detection/linknet/tensorflow.py +0 -369
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -308
- doctr/models/recognition/master/tensorflow.py +0 -313
- doctr/models/recognition/parseq/tensorflow.py +0 -508
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -416
- doctr/models/recognition/vitstr/tensorflow.py +0 -278
- doctr/models/utils/tensorflow.py +0 -182
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.11.0.dist-info/RECORD +0 -173
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
doctr/models/classification/vip/pytorch.py
ADDED
@@ -0,0 +1,505 @@
+# Copyright (C) 2021-2025, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from doctr.datasets import VOCABS
+from doctr.models.modules.layers import AdaptiveAvgPool2d
+
+from ...utils import load_pretrained_params
+from .layers import (
+    CrossShapedWindowAttention,
+    MultiHeadSelfAttention,
+    OSRABlock,
+    PatchEmbed,
+    PatchMerging,
+    PermuteLayer,
+    SqueezeLayer,
+)
+
+__all__ = ["vip_tiny", "vip_base"]
+
+default_cfgs: dict[str, dict[str, Any]] = {
+    "vip_tiny": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": "https://doctr-static.mindee.com/models?id=v0.11.0/vip_tiny-033ed51c.pt&src=0",
+    },
+    "vip_base": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": "https://doctr-static.mindee.com/models?id=v0.11.0/vip_base-f6ea2ff5.pt&src=0",
+    },
+}
+
+
+class ClassifierHead(nn.Module):
+    """Classification head which averages the features and applies a linear layer."""
+
+    def __init__(self, in_features: int, out_features: int):
+        super().__init__()
+        self.fc = nn.Linear(in_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.fc(x.mean(dim=1))
+
+
+class VIPBlock(nn.Module):
+    """Unified block for Local, Global, and Mixed feature mixing in VIP architecture."""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        local_unit: nn.ModuleList,
+        global_unit: nn.ModuleList | None = None,
+        proj: nn.Module | None = None,
+        downsample: bool = False,
+        out_dim: int | None = None,
+    ):
+        """
+        Args:
+            embed_dim: dimension of embeddings
+            local_unit: local mixing block(s)
+            global_unit: global mixing block(s)
+            proj: projection layer used for mixed mixing
+            downsample: whether to downsample at the end
+            out_dim: out channels if downsampling
+        """
+        super().__init__()
+        if downsample and out_dim is None:  # pragma: no cover
+            raise ValueError("`out_dim` must be specified if `downsample=True`")
+
+        self.local_unit = local_unit
+        self.global_unit = global_unit
+        self.proj = proj
+        self.downsample = PatchMerging(dim=embed_dim, out_dim=out_dim) if downsample else None  # type: ignore[arg-type]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for VIPBlock.
+
+        Args:
+            x: input tensor (B, H, W, C)
+
+        Returns:
+            Transformed tensor
+        """
+        b, h, w, C = x.shape
+
+        # Local or Mixed
+        if self.global_unit is None:
+            # local or global only
+            for blk in self.local_unit:
+                # Flatten to (B, H*W, C)
+                x = x.reshape(b, -1, C)
+                x = blk(x, (h, w))
+                x = x.reshape(b, h, w, -1)
+        else:
+            # Mixed
+            for lblk, gblk in zip(self.local_unit, self.global_unit):
+                x = x.reshape(b, -1, C)
+                # chunk into two halves
+                x1, x2 = torch.chunk(x, chunks=2, dim=2)
+                x1 = lblk(x1, (h, w))
+                x2 = gblk(x2, (h, w))
+                x = torch.cat([x1, x2], dim=2)
+                x = x.transpose(1, 2).contiguous().reshape(b, -1, h, w)
+                x = self.proj(x) + x  # type: ignore[misc]
+                x = x.permute(0, 2, 3, 1).contiguous()
+
+        if isinstance(self.downsample, nn.Module):
+            x = self.downsample(x)
+
+        return x
+
+
+class VIPNet(nn.Sequential):
+    """
+    VIP (Vision Permutable) encoder architecture, adapted for text recognition.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_dim: int,
+        embed_dims: list[int],
+        depths: list[int],
+        num_heads: list[int],
+        mlp_ratios: list[int],
+        split_sizes: list[int],
+        sr_ratios: list[int],
+        input_shape: tuple[int, int, int] = (3, 32, 32),
+        num_classes: int = 1000,
+        include_top: bool = True,
+        cfg: dict[str, Any] | None = None,
+    ) -> None:
+        """
+        Args:
+            in_channels: number of input channels
+            out_dim: final embedding dimension
+            embed_dims: list of embedding dims per stage
+            depths: number of blocks per stage
+            num_heads: number of heads for attention blocks
+            mlp_ratios: ratio for MLP expansion
+            split_sizes: local window split sizes
+            sr_ratios: used for some global block adjustments
+            input_shape: (C, H, W)
+            num_classes: number of output classes
+            include_top: if True, append a classification head
+            cfg: optional config dictionary
+        """
+        self.cfg = cfg
+
+        dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))]
+        drop_paths = [dpr[sum(depths[:i]) : sum(depths[: i + 1])] for i in range(len(depths))]
+        layers: list[Any] = [PatchEmbed(in_channels=in_channels, embed_dim=embed_dims[0])]
+
+        # Construct mixers
+        # e.g. local, mixed, global
+        mixer_functions = [
+            _vip_local_mixer,
+            _vip_mixed_mixer,
+            _vip_global_mha_mixer,
+        ]
+
+        for i, mixer_fn in enumerate(mixer_functions):
+            embed_dim = embed_dims[i]
+            depth_i = depths[i]
+            num_head = num_heads[i]
+            mlp_ratio = mlp_ratios[i]
+            sp_size = split_sizes[i]
+            sr_ratio = sr_ratios[i]
+            drop_path = drop_paths[i]
+
+            next_dim = embed_dims[i + 1] if i < len(embed_dims) - 1 else None
+
+            block = mixer_fn(
+                embed_dim=embed_dim,
+                depth=depth_i,
+                num_heads=num_head,
+                mlp_ratio=mlp_ratio,
+                split_size=sp_size,
+                sr_ratio=sr_ratio,
+                drop_path=drop_path,
+                downsample=(next_dim is not None),
+                out_dim=next_dim,
+            )
+            layers.append(block)
+
+        # LN -> permute -> GAP -> squeeze -> MLP
+        layers.append(
+            nn.Sequential(
+                nn.LayerNorm(embed_dims[-1], eps=1e-6),
+                PermuteLayer((0, 2, 3, 1)),
+                AdaptiveAvgPool2d((embed_dims[-1], 1)),
+                SqueezeLayer(dim=3),
+            )
+        )
+
+        mlp_head = nn.Sequential(
+            nn.Linear(embed_dims[-1], out_dim, bias=False),
+            nn.Hardswish(),
+            nn.Dropout(p=0.1),
+        )
+        layers.append(mlp_head)
+        if include_top:
+            layers.append(ClassifierHead(out_dim, num_classes))
+
+        super().__init__(*layers)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.Conv2d):
+            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+
+def vip_tiny(pretrained: bool = False, **kwargs: Any) -> VIPNet:
+    """
+    VIP-Tiny encoder architecture. Corresponds to SVIPTRv2-T variant in the paper (VIPTRv2 function
+    in the official implementation:
+    https://github.com/cxfyxl/VIPTR/blob/main/modules/VIPTRv2.py)
+
+    Args:
+        pretrained: whether to load pretrained weights
+        **kwargs: optional arguments
+
+    Returns:
+        VIPNet model
+    """
+    return _vip(
+        "vip_tiny",
+        pretrained,
+        in_channels=3,
+        out_dim=192,
+        embed_dims=[64, 128, 256],
+        depths=[3, 3, 3],
+        num_heads=[2, 4, 8],
+        mlp_ratios=[3, 4, 4],
+        split_sizes=[1, 2, 4],
+        sr_ratios=[4, 2, 2],
+        ignore_keys=["6.fc.weight", "6.fc.bias"],
+        **kwargs,
+    )
+
+
+def vip_base(pretrained: bool = False, **kwargs: Any) -> VIPNet:
+    """
+    VIP-Base encoder architecture. Corresponds to SVIPTRv2-B variant in the paper (VIPTRv2B function
+    in the official implementation:
+    https://github.com/cxfyxl/VIPTR/blob/main/modules/VIPTRv2.py)
+
+    Args:
+        pretrained: whether to load pretrained weights
+        **kwargs: optional arguments
+
+    Returns:
+        VIPNet model
+    """
+    return _vip(
+        "vip_base",
+        pretrained,
+        in_channels=3,
+        out_dim=256,
+        embed_dims=[128, 256, 384],
+        depths=[3, 6, 9],
+        num_heads=[4, 8, 12],
+        mlp_ratios=[4, 4, 4],
+        split_sizes=[1, 2, 4],
+        sr_ratios=[4, 2, 2],
+        ignore_keys=["6.fc.weight", "6.fc.bias"],
+        **kwargs,
+    )
+
+
+def _vip(
+    arch: str,
+    pretrained: bool,
+    ignore_keys: list[str],
+    **kwargs: Any,
+) -> VIPNet:
+    """
+    Internal constructor for the VIPNet models.
+
+    Args:
+        arch: architecture key
+        pretrained: load pretrained weights?
+        ignore_keys: layer keys to ignore
+        **kwargs: arguments passed to VIPNet
+
+    Returns:
+        VIPNet instance
+    """
+    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+    kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["num_classes"] = kwargs["num_classes"]
+    _cfg["input_shape"] = kwargs["input_shape"]
+    _cfg["classes"] = kwargs["classes"]
+    kwargs.pop("classes")
+
+    model = VIPNet(cfg=_cfg, **kwargs)
+    if pretrained:
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+    return model
+
+
+############################################
+# _vip_local_mixer
+############################################
+def _vip_local_mixer(
+    embed_dim: int,
+    depth: int,
+    num_heads: int,
+    mlp_ratio: float,
+    drop_path: list[float],
+    split_size: int = 1,
+    sr_ratio: int = 1,
+    downsample: bool = False,
+    out_dim: int | None = None,
+) -> nn.Module:
+    """Builds a VIPBlock performing local (cross-shaped) window attention.
+
+    Args:
+        embed_dim: embedding dimension.
+        depth: number of attention blocks in this stage.
+        num_heads: number of attention heads.
+        mlp_ratio: ratio used to expand the hidden dimension in MLP.
+        split_size: size of the local window splits.
+        sr_ratio: parameter needed for cross-compatibility between different mixers
+        drop_path: list of per-block drop path rates.
+        downsample: whether to apply PatchMerging at the end.
+        out_dim: output embedding dimension if downsampling.
+
+    Returns:
+        A VIPBlock (local attention) for one stage of the VIP network.
+    """
+    blocks = nn.ModuleList([
+        CrossShapedWindowAttention(
+            dim=embed_dim,
+            num_heads=num_heads,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=True,
+            split_size=split_size,
+            drop_path=drop_path[i],
+        )
+        for i in range(depth)
+    ])
+    return VIPBlock(embed_dim, local_unit=blocks, downsample=downsample, out_dim=out_dim)
+
+
+############################################
+# _vip_global_mha_mixer
+############################################
+def _vip_global_mha_mixer(
+    embed_dim: int,
+    depth: int,
+    num_heads: int,
+    mlp_ratio: float,
+    drop_path: list[float],
+    split_size: int = 1,
+    sr_ratio: int = 1,
+    downsample: bool = False,
+    out_dim: int | None = None,
+) -> nn.Module:
+    """Builds a VIPBlock performing global multi-head self-attention.
+
+    Args:
+        embed_dim: embedding dimension.
+        depth: number of attention blocks in this stage.
+        num_heads: number of attention heads.
+        mlp_ratio: ratio used to expand the hidden dimension in MLP.
+        drop_path: list of per-block drop path rates.
+        split_size: parameter needed for cross-compatibility between different mixers
+        sr_ratio: parameter needed for cross-compatibility between different mixers
+        downsample: whether to apply PatchMerging at the end.
+        out_dim: output embedding dimension if downsampling.
+
+    Returns:
+        A VIPBlock (global MHA) for one stage of the VIP network.
+    """
+    blocks = nn.ModuleList([
+        MultiHeadSelfAttention(
+            dim=embed_dim,
+            num_heads=num_heads,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=True,
+            drop_path_rate=drop_path[i],
+        )
+        for i in range(depth)
+    ])
+    return VIPBlock(
+        embed_dim,
+        local_unit=blocks,  # In this context, they are "global" blocks but stored in local_unit
+        downsample=downsample,
+        out_dim=out_dim,
+    )
+
+
+############################################
+# _vip_mixed_mixer
+############################################
+def _vip_mixed_mixer(
+    embed_dim: int,
+    depth: int,
+    num_heads: int,
+    mlp_ratio: float,
+    drop_path: list[float],
+    split_size: int = 1,
+    sr_ratio: int = 1,
+    downsample: bool = False,
+    out_dim: int | None = None,
+) -> nn.Module:
+    """Builds a VIPBlock performing mixed local+global attention.
+
+    Args:
+        embed_dim: embedding dimension.
+        depth: number of attention blocks in this stage.
+        num_heads: total number of attention heads.
+        mlp_ratio: ratio used to expand the hidden dimension in MLP.
+        drop_path: list of per-block drop path rates.
+        split_size: size of the local window splits (for the local half).
+        sr_ratio: reduce spatial resolution in the global half (OSRA).
+        downsample: whether to apply PatchMerging at the end.
+        out_dim: output embedding dimension if downsampling.
+
+    Returns:
+        A VIPBlock (mixed local+global) for one stage of the VIP network.
+    """
+    # an inner dimension for the conv-projection
+    inner_dim = max(16, embed_dim // 8)
+    proj = nn.Sequential(
+        nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, groups=embed_dim),
+        nn.GELU(),
+        nn.BatchNorm2d(embed_dim),
+        nn.Conv2d(embed_dim, inner_dim, kernel_size=1),
+        nn.GELU(),
+        nn.BatchNorm2d(inner_dim),
+        nn.Conv2d(inner_dim, embed_dim, kernel_size=1),
+        nn.BatchNorm2d(embed_dim),
+    )
+
+    # local half blocks
+    local_unit = nn.ModuleList([
+        CrossShapedWindowAttention(
+            dim=embed_dim // 2,
+            num_heads=num_heads,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=True,
+            split_size=split_size,
+            drop_path=drop_path[i],
+        )
+        for i in range(depth)
+    ])
+
+    # global half blocks
+    global_unit = nn.ModuleList([
+        OSRABlock(
+            dim=embed_dim // 2,
+            sr_ratio=sr_ratio,
+            num_heads=num_heads // 2,
+            mlp_ratio=mlp_ratio,
+            drop_path=drop_path[i],
+        )
+        for i in range(depth)
+    ])
+
+    return VIPBlock(
+        embed_dim,
+        local_unit=local_unit,
+        global_unit=global_unit,
+        proj=proj,
+        downsample=downsample,
+        out_dim=out_dim,
+    )
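The added backbone is a plain nn.Sequential, so it can be smoke-tested on its own. A minimal sketch (not part of the diff), assuming vip_tiny is re-exported through doctr.models.classification as the one-line __init__.py additions in the file listing suggest:

import torch

from doctr.models.classification import vip_tiny

# Shapes follow the default_cfgs entry above: (3, 32, 32) inputs,
# with num_classes defaulting to len(VOCABS["french"]).
model = vip_tiny(pretrained=False)
out = model(torch.rand(2, 3, 32, 32))
print(out.shape)  # expected: (2, num_classes) with the default classification head

In the release itself the backbone mainly feeds the new VIPTR recognition model (doctr/models/recognition/viptr/pytorch.py in the listing) rather than being used as a standalone classifier.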
doctr/models/classification/vit/pytorch.py
CHANGED
@@ -11,9 +11,9 @@ from torch import nn
 
 from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import EncoderBlock
-from doctr.models.modules.vision_transformer
+from doctr.models.modules.vision_transformer import PatchEmbedding
 
-from ...utils
+from ...utils import load_pretrained_params
 
 __all__ = ["vit_s", "vit_b"]
 
@@ -98,6 +98,15 @@ class VisionTransformer(nn.Sequential):
         super().__init__(*_layers)
         self.cfg = cfg
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
 
 def _vit(
     arch: str,
@@ -122,7 +131,7 @@ def _vit(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
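Both hunks above route weight loading through the new from_pretrained method instead of calling load_pretrained_params directly inside _vit. A short usage sketch with a hypothetical local checkpoint path:

from doctr.models.classification import vit_s

# Build the architecture without downloading weights, then attach a checkpoint;
# from_pretrained forwards its kwargs (e.g. ignore_keys) to load_pretrained_params.
model = vit_s(pretrained=False)
model.from_pretrained("path/to/vit_s_checkpoint.pt")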
doctr/models/classification/zoo.py
CHANGED
@@ -5,7 +5,7 @@
 
 from typing import Any
 
-from doctr.
+from doctr.models.utils import _CompiledModule
 
 from .. import classification
 from ..preprocessor import PreProcessor
@@ -30,7 +30,10 @@ ARCHS: list[str] = [
     "vgg16_bn_r",
     "vit_s",
     "vit_b",
+    "vip_tiny",
+    "vip_base",
 ]
+
 ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
@@ -48,12 +51,8 @@ def _orientation_predictor(
         # Load directly classifier from backbone
        _model = classification.__dict__[arch](pretrained=pretrained)
    else:
-
-
-        # Adding the type for torch compiled models to the allowed architectures
-        from doctr.models.utils import _CompiledModule
-
-        allowed_archs.append(_CompiledModule)
+        # Adding the type for torch compiled models to the allowed architectures
+        allowed_archs = [classification.MobileNetV3, _CompiledModule]
 
    if not isinstance(arch, tuple(allowed_archs)):
        raise ValueError(f"unknown architecture: {type(arch)}")
@@ -62,7 +61,7 @@ def _orientation_predictor(
    kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
    kwargs["std"] = kwargs.get("std", _model.cfg["std"])
    kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
-    input_shape = _model.cfg["input_shape"][
+    input_shape = _model.cfg["input_shape"][1:]
    predictor = OrientationPredictor(
        PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
    )
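The zoo changes register the two VIP backbones and simplify the architecture check so that torch.compile'd models (_CompiledModule) are accepted next to classification.MobileNetV3. A hedged sketch of what this enables, assuming the public crop_orientation_predictor wrapper forwards its arch argument to _orientation_predictor as shown above:

import torch

from doctr.models import classification
from doctr.models.classification import crop_orientation_predictor

# Pass an already-instantiated (and optionally compiled) orientation backbone;
# the isinstance check above now whitelists the compiled wrapper type directly.
backbone = classification.mobilenet_v3_small_crop_orientation(pretrained=True)
predictor = crop_orientation_predictor(arch=torch.compile(backbone))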
doctr/models/detection/core.py
CHANGED
@@ -53,7 +53,7 @@ class DetectionPostProcessor(NestedObject):
 
        else:
            mask: np.ndarray = np.zeros((h, w), np.int32)
-            cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
+            cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
            product = pred * mask
        return np.sum(product) / np.count_nonzero(product)
 
doctr/models/detection/differentiable_binarization/base.py
CHANGED
@@ -58,9 +58,8 @@ class DBPostProcessor(DetectionPostProcessor):
            area = (rect[1][0] + 1) * (1 + rect[1][1])
            length = 2 * (rect[1][0] + rect[1][1]) + 2
        else:
-
-
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
        distance = area * self.unclip_ratio / length  # compute distance to expand polygon
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
@@ -206,7 +205,7 @@ class _DBNet:
        canvas: np.ndarray,
        mask: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-        """Draw a polygon
+        """Draw a polygon threshold map on a canvas, as described in the DB paper
 
        Args:
            polygon : array of coord., to draw the boundary of the polygon
@@ -225,7 +224,7 @@ class _DBNet:
        padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0])
 
        # Fill the mask with 1 on the new padded polygon
-        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
+        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
 
        # Get min/max to recover polygon after distance computation
        xmin = padded_polygon[:, 0].min()
@@ -270,7 +269,6 @@ class _DBNet:
        self,
        target: list[dict[str, np.ndarray]],
        output_shape: tuple[int, int, int],
-        channels_last: bool = True,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
            raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
@@ -281,10 +279,8 @@ class _DBNet:
 
        h: int
        w: int
-
-
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
        target_shape = (len(target), num_classes, h, w)
 
        seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -344,17 +340,12 @@ class _DBNet:
                if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                    seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                    continue
-                cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
+                cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
 
                # Draw on both thresh map and thresh mask
                poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
                    poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]
                )
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
-            thresh_target = thresh_target.transpose((0, 2, 3, 1))
-            thresh_mask = thresh_mask.transpose((0, 2, 3, 1))
 
        thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min
 
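For reference, the DBPostProcessor hunk above replaces the removed poly.length-based computation with OpenCV's contourArea/arcLength while keeping the same unclip formula. A standalone sketch of that expansion step (hypothetical unclip_ratio value, not a doctr API):

import cv2
import numpy as np
import pyclipper

points = np.array([[10, 10], [60, 10], [60, 30], [10, 30]], dtype=np.int32)
unclip_ratio = 1.5  # assumption: doctr stores this as a post-processor attribute (self.unclip_ratio)

area = cv2.contourArea(points)
length = cv2.arcLength(points, closed=True)
distance = area * unclip_ratio / length  # expansion distance, as in the diff

offset = pyclipper.PyclipperOffset()
offset.AddPath(points.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance)[0])  # (N, 2) expanded polygon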