yomitoku-0.4.0.post1.dev0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yomitoku/__init__.py +20 -0
- yomitoku/base.py +136 -0
- yomitoku/cli/__init__.py +0 -0
- yomitoku/cli/main.py +230 -0
- yomitoku/configs/__init__.py +13 -0
- yomitoku/configs/cfg_layout_parser_rtdtrv2.py +89 -0
- yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py +80 -0
- yomitoku/configs/cfg_text_detector_dbnet.py +49 -0
- yomitoku/configs/cfg_text_recognizer_parseq.py +51 -0
- yomitoku/constants.py +32 -0
- yomitoku/data/__init__.py +3 -0
- yomitoku/data/dataset.py +40 -0
- yomitoku/data/functions.py +279 -0
- yomitoku/document_analyzer.py +315 -0
- yomitoku/export/__init__.py +6 -0
- yomitoku/export/export_csv.py +71 -0
- yomitoku/export/export_html.py +188 -0
- yomitoku/export/export_json.py +34 -0
- yomitoku/export/export_markdown.py +145 -0
- yomitoku/layout_analyzer.py +66 -0
- yomitoku/layout_parser.py +189 -0
- yomitoku/models/__init__.py +9 -0
- yomitoku/models/dbnet_plus.py +272 -0
- yomitoku/models/layers/__init__.py +0 -0
- yomitoku/models/layers/activate.py +38 -0
- yomitoku/models/layers/dbnet_feature_attention.py +160 -0
- yomitoku/models/layers/parseq_transformer.py +218 -0
- yomitoku/models/layers/rtdetr_backbone.py +333 -0
- yomitoku/models/layers/rtdetr_hybrid_encoder.py +433 -0
- yomitoku/models/layers/rtdetrv2_decoder.py +811 -0
- yomitoku/models/parseq.py +243 -0
- yomitoku/models/rtdetr.py +22 -0
- yomitoku/ocr.py +87 -0
- yomitoku/postprocessor/__init__.py +9 -0
- yomitoku/postprocessor/dbnet_postporcessor.py +137 -0
- yomitoku/postprocessor/parseq_tokenizer.py +128 -0
- yomitoku/postprocessor/rtdetr_postprocessor.py +107 -0
- yomitoku/reading_order.py +214 -0
- yomitoku/resource/MPLUS1p-Medium.ttf +0 -0
- yomitoku/resource/charset.txt +1 -0
- yomitoku/table_structure_recognizer.py +244 -0
- yomitoku/text_detector.py +103 -0
- yomitoku/text_recognizer.py +128 -0
- yomitoku/utils/__init__.py +0 -0
- yomitoku/utils/graph.py +20 -0
- yomitoku/utils/logger.py +15 -0
- yomitoku/utils/misc.py +102 -0
- yomitoku/utils/visualizer.py +179 -0
- yomitoku-0.4.0.post1.dev0.dist-info/METADATA +127 -0
- yomitoku-0.4.0.post1.dev0.dist-info/RECORD +52 -0
- yomitoku-0.4.0.post1.dev0.dist-info/WHEEL +4 -0
- yomitoku-0.4.0.post1.dev0.dist-info/entry_points.txt +2 -0
yomitoku/models/dbnet_plus.py
@@ -0,0 +1,272 @@
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from huggingface_hub import PyTorchModelHubMixin
from torchvision.models._utils import IntermediateLayerGetter

from .layers.dbnet_feature_attention import ScaleFeatureSelection


class BackboneBase(nn.Module):
    def __init__(self, backbone: nn.Module):
        super().__init__()
        return_layers = {
            "layer1": "layer1",
            "layer2": "layer2",
            "layer3": "layer3",
            "layer4": "layer4",
        }

        self.body = IntermediateLayerGetter(
            backbone, return_layers=return_layers
        )

    def forward(self, tensor):
        xs = self.body(tensor)
        return xs


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""

    def __init__(self, name="resnet50", dilation=True):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=False,
        )
        super().__init__(backbone)


class DBNetDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_dim,
        adaptive=False,
        serial=False,
        smooth=False,
        k=50,
    ):
        super().__init__()
        self.d_model = hidden_dim
        self.n_layers = len(in_channels)
        self.k = k
        self.training = True
        self.input_proj = nn.ModuleDict(
            {
                "layer1": nn.Conv2d(
                    in_channels[0], self.d_model, 1, bias=False
                ),
                "layer2": nn.Conv2d(
                    in_channels[1], self.d_model, 1, bias=False
                ),
                "layer3": nn.Conv2d(
                    in_channels[2], self.d_model, 1, bias=False
                ),
                "layer4": nn.Conv2d(
                    in_channels[3], self.d_model, 1, bias=False
                ),
            }
        )

        self.upsample_2x = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        self.out_proj = nn.ModuleDict(
            {
                "layer1": nn.Conv2d(
                    self.d_model, self.d_model // 4, 3, padding=1, bias=False
                ),
                "layer2": nn.Sequential(
                    nn.Conv2d(
                        self.d_model,
                        self.d_model // 4,
                        3,
                        padding=1,
                        bias=False,
                    ),
                    nn.Upsample(
                        scale_factor=2, mode="bilinear", align_corners=False
                    ),
                ),
                "layer3": nn.Sequential(
                    nn.Conv2d(
                        self.d_model,
                        self.d_model // 4,
                        3,
                        padding=1,
                        bias=False,
                    ),
                    nn.Upsample(
                        scale_factor=4, mode="bilinear", align_corners=False
                    ),
                ),
                "layer4": nn.Sequential(
                    nn.Conv2d(
                        self.d_model,
                        self.d_model // 4,
                        3,
                        padding=1,
                        bias=False,
                    ),
                    nn.Upsample(
                        scale_factor=4, mode="bilinear", align_corners=False
                    ),
                ),
            }
        )

        self.binarize = nn.Sequential(
            nn.Conv2d(
                self.d_model, self.d_model // 4, 3, padding=1, bias=False
            ),
            nn.BatchNorm2d(self.d_model // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(self.d_model // 4, self.d_model // 4, 2, 2),
            nn.BatchNorm2d(self.d_model // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(self.d_model // 4, 1, 2, 2),
            nn.Sigmoid(),
        )

        self.adaptive = adaptive
        self.serial = serial
        if self.adaptive:
            self.thresh = self._init_thresh(
                self.d_model,
                serial=serial,
                smooth=smooth,
                bias=False,
            )
            self.thresh.apply(self.weights_init)

        self.binarize.apply(self.weights_init)

        for layer in self.input_proj.values():
            layer.apply(self.weights_init)

        for layer in self.out_proj.values():
            layer.apply(self.weights_init)

        self.concat_attention = ScaleFeatureSelection(
            self.d_model,
            self.d_model // 4,
            attention_type="scale_channel_spatial",
        )

    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find("BatchNorm") != -1:
            m.weight.data.fill_(1.0)
            m.bias.data.fill_(1e-4)

    def _init_thresh(
        self, inner_channels, serial=False, smooth=False, bias=False
    ):
        in_channels = inner_channels
        if serial:
            in_channels += 1
        self.thresh = nn.Sequential(
            nn.Conv2d(
                in_channels, inner_channels // 4, 3, padding=1, bias=bias
            ),
            nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(
                inner_channels // 4,
                inner_channels // 4,
                smooth=smooth,
                bias=bias,
            ),
            nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(
                inner_channels // 4, 1, smooth=smooth, bias=bias
            ),
            nn.Sigmoid(),
        )
        return self.thresh

    def _init_upsample(
        self, in_channels, out_channels, smooth=False, bias=False
    ):
        if smooth:
            inter_out_channels = out_channels
            if out_channels == 1:
                inter_out_channels = in_channels
            module_list = [
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias),
            ]
            if out_channels == 1:
                module_list.append(
                    nn.Conv2d(
                        in_channels,
                        out_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        bias=False,
                    )
                )

            return nn.Sequential(*module_list)
        else:
            return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)

    def step_function(self, x, y):
        return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))

    def forward(self, features):
        for layer, feature in features.items():
            features[layer] = self.input_proj[layer](feature)

        layers = ["layer4", "layer3", "layer2", "layer1"]
        for i in range(self.n_layers - 1):
            feature_bottom = features[layers[i]]
            feature_top = features[layers[i + 1]]

            bh, bw = feature_bottom.shape[-2:]
            th, tw = feature_top.shape[-2:]

            if bh != th or bw != tw:
                feature_bottom = F.interpolate(
                    feature_bottom,
                    size=(th, tw),
                    mode="bilinear",
                    align_corners=False,
                )

            features[layers[i + 1]] = feature_bottom + feature_top

        fp = []
        for layer, feature in features.items():
            fp.append(self.out_proj[layer](feature))
        fuse = torch.cat(fp[::-1], dim=1)
        fuse = self.concat_attention(fuse, fp[::-1])

        binary = self.binarize(fuse)
        result = OrderedDict(binary=binary)
        return result


class DBNet(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        cfg,
    ):
        super().__init__()
        self.cfg = cfg
        self.backbone = Backbone(**cfg.backbone)
        self.decoder = DBNetDecoder(**cfg.decoder)

    def forward(self, tensor):
        features = self.backbone(tensor)
        xs = self.decoder(features)
        return xs
yomitoku/models/layers/__init__.py
File without changes
yomitoku/models/layers/activate.py
@@ -0,0 +1,38 @@
import torch.nn as nn


def get_activation(act: str, inplace: bool = True):
    """get activation"""
    if act is None:
        return nn.Identity()

    elif isinstance(act, nn.Module):
        return act

    act = act.lower()

    if act == "silu" or act == "swish":
        m = nn.SiLU()

    elif act == "relu":
        m = nn.ReLU()

    elif act == "leaky_relu":
        m = nn.LeakyReLU()

    elif act == "silu":
        m = nn.SiLU()

    elif act == "gelu":
        m = nn.GELU()

    elif act == "hardsigmoid":
        m = nn.Hardsigmoid()

    else:
        raise RuntimeError("")

    if hasattr(m, "inplace"):
        m.inplace = inplace

    return m
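`get_activation` resolves an activation name string from the model configs to the corresponding `nn` module. Note that the second `"silu"` branch is unreachable (the first branch already matches it) and that unknown names raise a `RuntimeError` with an empty message. A brief, hedged usage sketch:

# Illustrative only: resolves config strings to activation modules.
from yomitoku.models.layers.activate import get_activation

silu = get_activation("silu")  # nn.SiLU with inplace=True
hsig = get_activation("hardsigmoid", inplace=False)  # nn.Hardsigmoid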
yomitoku/models/layers/dbnet_feature_attention.py
@@ -0,0 +1,160 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaleChannelAttention(nn.Module):
    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super(ScaleChannelAttention, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        print(self.avgpool)
        self.fc1 = nn.Conv2d(in_planes, out_planes, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.fc2 = nn.Conv2d(out_planes, num_features, 1, bias=False)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        global_x = self.avgpool(x)
        global_x = self.fc1(global_x)
        global_x = F.relu(self.bn(global_x))
        global_x = self.fc2(global_x)
        global_x = F.softmax(global_x, 1)
        return global_x


class ScaleChannelSpatialAttention(nn.Module):
    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super(ScaleChannelSpatialAttention, self).__init__()
        self.channel_wise = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_planes, out_planes, 1, bias=False),
            # nn.BatchNorm2d(out_planes),
            nn.ReLU(),
            nn.Conv2d(out_planes, in_planes, 1, bias=False),
        )
        self.spatial_wise = nn.Sequential(
            # Nx1xHxW
            nn.Conv2d(1, 1, 3, bias=False, padding=1),
            nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False),
            nn.Sigmoid(),
        )
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_planes, num_features, 1, bias=False), nn.Sigmoid()
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # global_x = self.avgpool(x)
        # shape Nx4x1x1
        global_x = self.channel_wise(x).sigmoid()
        # shape: NxCxHxW
        global_x = global_x + x
        # shape:Nx1xHxW
        x = torch.mean(global_x, dim=1, keepdim=True)
        global_x = self.spatial_wise(x) + global_x
        global_x = self.attention_wise(global_x)
        return global_x


class ScaleSpatialAttention(nn.Module):
    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super(ScaleSpatialAttention, self).__init__()
        self.spatial_wise = nn.Sequential(
            # Nx1xHxW
            nn.Conv2d(1, 1, 3, bias=False, padding=1),
            nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False),
            nn.Sigmoid(),
        )
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_planes, num_features, 1, bias=False), nn.Sigmoid()
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        global_x = torch.mean(x, dim=1, keepdim=True)
        global_x = self.spatial_wise(global_x) + x
        global_x = self.attention_wise(global_x)
        return global_x


class ScaleFeatureSelection(nn.Module):
    def __init__(
        self,
        in_channels,
        inter_channels,
        out_features_num=4,
        attention_type="scale_spatial",
    ):
        super(ScaleFeatureSelection, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.out_features_num = out_features_num
        self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)
        self.type = attention_type
        if self.type == "scale_spatial":
            self.enhanced_attention = ScaleSpatialAttention(
                inter_channels, inter_channels // 4, out_features_num
            )
        elif self.type == "scale_channel_spatial":
            self.enhanced_attention = ScaleChannelSpatialAttention(
                inter_channels, inter_channels // 4, out_features_num
            )
        elif self.type == "scale_channel":
            self.enhanced_attention = ScaleChannelAttention(
                inter_channels, inter_channels // 2, out_features_num
            )

    def _initialize_weights(self, m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find("BatchNorm") != -1:
            m.weight.data.fill_(1.0)
            m.bias.data.fill_(1e-4)

    def forward(self, concat_x, features_list):
        concat_x = self.conv(concat_x)
        score = self.enhanced_attention(concat_x)
        assert len(features_list) == self.out_features_num
        if self.type not in ["scale_channel_spatial", "scale_spatial"]:
            shape = features_list[0].shape[2:]
            score = F.interpolate(score, size=shape, mode="bilinear")
        x = []
        for i in range(self.out_features_num):
            x.append(score[:, i : i + 1] * features_list[i])
        return torch.cat(x, dim=1)
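`ScaleFeatureSelection` is what `DBNetDecoder` registers as `concat_attention`: it takes the channel-wise concatenation of the four per-level feature maps plus the list of those maps, computes one attention map per level, and re-weights each level before concatenating again. A shape-only sketch under assumed sizes (a 256-channel fuse built from four 64-channel levels at 160x160):

# Hypothetical shapes, mirroring how DBNetDecoder drives ScaleFeatureSelection.
import torch

from yomitoku.models.layers.dbnet_feature_attention import ScaleFeatureSelection

d_model = 256  # assumed decoder width
fp = [torch.randn(1, d_model // 4, 160, 160) for _ in range(4)]  # per-level maps
fuse = torch.cat(fp, dim=1)  # (1, 256, 160, 160)

attn = ScaleFeatureSelection(
    d_model, d_model // 4, attention_type="scale_channel_spatial"
)
out = attn(fuse, fp)  # per-level weights applied, then re-concatenated
print(out.shape)  # torch.Size([1, 256, 160, 160])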
yomitoku/models/layers/parseq_transformer.py
@@ -0,0 +1,218 @@
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional

import torch
from timm.models.vision_transformer import PatchEmbed, VisionTransformer
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.modules import transformer


class DecoderLayer(nn.Module):
    """A Transformer decoder layer supporting two-stream attention (XLNet)
    This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        mlp_ratio=2048,
        dropout=0.1,
        activation="gelu",
        layer_norm_eps=1e-5,
    ):
        super().__init__()
        dim_feedforward = embed_dim * mlp_ratio
        self.self_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(embed_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.norm_q = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.norm_c = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = transformer._get_activation_fn(activation)

    def __setstate__(self, state):
        if "activation" not in state:
            state["activation"] = F.gelu
        super().__setstate__(state)

    def forward_stream(
        self,
        tgt: Tensor,
        tgt_norm: Tensor,
        tgt_kv: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor],
        tgt_key_padding_mask: Optional[Tensor],
    ):
        """Forward pass for a single stream (i.e. content or query)
        tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
        Both tgt_kv and memory are expected to be LayerNorm'd too.
        memory is LayerNorm'd by ViT.
        """
        tgt2, sa_weights = self.self_attn(
            tgt_norm,
            tgt_kv,
            tgt_kv,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        tgt = tgt + self.dropout1(tgt2)

        tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
        tgt = tgt + self.dropout2(tgt2)

        tgt2 = self.linear2(
            self.dropout(self.activation(self.linear1(self.norm2(tgt))))
        )
        tgt = tgt + self.dropout3(tgt2)
        return tgt, sa_weights, ca_weights

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[Tensor] = None,
        content_mask: Optional[Tensor] = None,
        content_key_padding_mask: Optional[Tensor] = None,
        update_content: bool = True,
    ):
        query_norm = self.norm_q(query)
        content_norm = self.norm_c(content)
        query = self.forward_stream(
            query,
            query_norm,
            content_norm,
            memory,
            query_mask,
            content_key_padding_mask,
        )[0]
        if update_content:
            content = self.forward_stream(
                content,
                content_norm,
                content_norm,
                memory,
                content_mask,
                content_key_padding_mask,
            )[0]
        return query, content


class Decoder(nn.Module):
    __constants__ = ["norm"]

    def __init__(self, norm, cfg):
        super().__init__()
        decoder_layer = DecoderLayer(
            embed_dim=cfg.embed_dim,
            num_heads=cfg.num_heads,
            mlp_ratio=cfg.mlp_ratio,
        )

        self.layers = transformer._get_clones(decoder_layer, cfg.depth)
        self.num_layers = cfg.depth
        self.norm = norm

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[Tensor] = None,
        content_mask: Optional[Tensor] = None,
        content_key_padding_mask: Optional[Tensor] = None,
    ):
        for i, mod in enumerate(self.layers):
            last = i == len(self.layers) - 1
            query, content = mod(
                query,
                content,
                memory,
                query_mask,
                content_mask,
                content_key_padding_mask,
                update_content=not last,
            )
        query = self.norm(query)
        return query


class Encoder(VisionTransformer):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        embed_layer=PatchEmbed,
    ):
        super().__init__(
            img_size,
            patch_size,
            in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            embed_layer=embed_layer,
            num_classes=0,  # These
            global_pool="",  # disable the
            class_token=False,  # classifier head.
        )

    def forward(self, x):
        # Return all tokens
        return self.forward_features(x)


class TokenEmbedding(nn.Module):
    def __init__(self, charset_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(charset_size, embed_dim)
        self.embed_dim = embed_dim

    def forward(self, tokens: torch.Tensor):
        return math.sqrt(self.embed_dim) * self.embedding(tokens)
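The decoder layer above implements PARSeq's two-stream (query/content) attention over the ViT memory; `Decoder.forward` threads both streams through every layer and stops updating the content stream on the last one. A minimal, hedged sketch of the call pattern with assumed dimensions (the packaged sizes live in `cfg_text_recognizer_parseq.py`, which this sketch does not reproduce):

# Hypothetical dimensions; attention masks are omitted for brevity.
import torch

from yomitoku.models.layers.parseq_transformer import DecoderLayer

layer = DecoderLayer(embed_dim=384, num_heads=6, mlp_ratio=4)  # assumed sizes

memory = torch.randn(2, 128, 384)  # ViT tokens: (batch, seq, embed)
query = torch.randn(2, 26, 384)    # positional queries
content = torch.randn(2, 26, 384)  # content (token) embeddings

query, content = layer(query, content, memory)
print(query.shape, content.shape)  # torch.Size([2, 26, 384]) each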