yomitoku 0.4.0.post1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. yomitoku/__init__.py +20 -0
  2. yomitoku/base.py +136 -0
  3. yomitoku/cli/__init__.py +0 -0
  4. yomitoku/cli/main.py +230 -0
  5. yomitoku/configs/__init__.py +13 -0
  6. yomitoku/configs/cfg_layout_parser_rtdtrv2.py +89 -0
  7. yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py +80 -0
  8. yomitoku/configs/cfg_text_detector_dbnet.py +49 -0
  9. yomitoku/configs/cfg_text_recognizer_parseq.py +51 -0
  10. yomitoku/constants.py +32 -0
  11. yomitoku/data/__init__.py +3 -0
  12. yomitoku/data/dataset.py +40 -0
  13. yomitoku/data/functions.py +279 -0
  14. yomitoku/document_analyzer.py +315 -0
  15. yomitoku/export/__init__.py +6 -0
  16. yomitoku/export/export_csv.py +71 -0
  17. yomitoku/export/export_html.py +188 -0
  18. yomitoku/export/export_json.py +34 -0
  19. yomitoku/export/export_markdown.py +145 -0
  20. yomitoku/layout_analyzer.py +66 -0
  21. yomitoku/layout_parser.py +189 -0
  22. yomitoku/models/__init__.py +9 -0
  23. yomitoku/models/dbnet_plus.py +272 -0
  24. yomitoku/models/layers/__init__.py +0 -0
  25. yomitoku/models/layers/activate.py +38 -0
  26. yomitoku/models/layers/dbnet_feature_attention.py +160 -0
  27. yomitoku/models/layers/parseq_transformer.py +218 -0
  28. yomitoku/models/layers/rtdetr_backbone.py +333 -0
  29. yomitoku/models/layers/rtdetr_hybrid_encoder.py +433 -0
  30. yomitoku/models/layers/rtdetrv2_decoder.py +811 -0
  31. yomitoku/models/parseq.py +243 -0
  32. yomitoku/models/rtdetr.py +22 -0
  33. yomitoku/ocr.py +87 -0
  34. yomitoku/postprocessor/__init__.py +9 -0
  35. yomitoku/postprocessor/dbnet_postporcessor.py +137 -0
  36. yomitoku/postprocessor/parseq_tokenizer.py +128 -0
  37. yomitoku/postprocessor/rtdetr_postprocessor.py +107 -0
  38. yomitoku/reading_order.py +214 -0
  39. yomitoku/resource/MPLUS1p-Medium.ttf +0 -0
  40. yomitoku/resource/charset.txt +1 -0
  41. yomitoku/table_structure_recognizer.py +244 -0
  42. yomitoku/text_detector.py +103 -0
  43. yomitoku/text_recognizer.py +128 -0
  44. yomitoku/utils/__init__.py +0 -0
  45. yomitoku/utils/graph.py +20 -0
  46. yomitoku/utils/logger.py +15 -0
  47. yomitoku/utils/misc.py +102 -0
  48. yomitoku/utils/visualizer.py +179 -0
  49. yomitoku-0.4.0.post1.dev0.dist-info/METADATA +127 -0
  50. yomitoku-0.4.0.post1.dev0.dist-info/RECORD +52 -0
  51. yomitoku-0.4.0.post1.dev0.dist-info/WHEEL +4 -0
  52. yomitoku-0.4.0.post1.dev0.dist-info/entry_points.txt +2 -0
yomitoku/models/dbnet_plus.py
@@ -0,0 +1,272 @@
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+ from huggingface_hub import PyTorchModelHubMixin
+ from torchvision.models._utils import IntermediateLayerGetter
+
+ from .layers.dbnet_feature_attention import ScaleFeatureSelection
+
+
+ class BackboneBase(nn.Module):
+     def __init__(self, backbone: nn.Module):
+         super().__init__()
+         return_layers = {
+             "layer1": "layer1",
+             "layer2": "layer2",
+             "layer3": "layer3",
+             "layer4": "layer4",
+         }
+
+         self.body = IntermediateLayerGetter(
+             backbone, return_layers=return_layers
+         )
+
+     def forward(self, tensor):
+         xs = self.body(tensor)
+         return xs
+
+
+ class Backbone(BackboneBase):
+     """ResNet backbone with frozen BatchNorm."""
+
+     def __init__(self, name="resnet50", dilation=True):
+         backbone = getattr(torchvision.models, name)(
+             replace_stride_with_dilation=[False, False, dilation],
+             pretrained=False,
+         )
+         super().__init__(backbone)
+
+
+ class DBNetDecoder(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         hidden_dim,
+         adaptive=False,
+         serial=False,
+         smooth=False,
+         k=50,
+     ):
+         super().__init__()
+         self.d_model = hidden_dim
+         self.n_layers = len(in_channels)
+         self.k = k
+         self.training = True
+         self.input_proj = nn.ModuleDict(
+             {
+                 "layer1": nn.Conv2d(
+                     in_channels[0], self.d_model, 1, bias=False
+                 ),
+                 "layer2": nn.Conv2d(
+                     in_channels[1], self.d_model, 1, bias=False
+                 ),
+                 "layer3": nn.Conv2d(
+                     in_channels[2], self.d_model, 1, bias=False
+                 ),
+                 "layer4": nn.Conv2d(
+                     in_channels[3], self.d_model, 1, bias=False
+                 ),
+             }
+         )
+
+         self.upsample_2x = nn.Upsample(
+             scale_factor=2, mode="bilinear", align_corners=False
+         )
+
+         self.out_proj = nn.ModuleDict(
+             {
+                 "layer1": nn.Conv2d(
+                     self.d_model, self.d_model // 4, 3, padding=1, bias=False
+                 ),
+                 "layer2": nn.Sequential(
+                     nn.Conv2d(
+                         self.d_model,
+                         self.d_model // 4,
+                         3,
+                         padding=1,
+                         bias=False,
+                     ),
+                     nn.Upsample(
+                         scale_factor=2, mode="bilinear", align_corners=False
+                     ),
+                 ),
+                 "layer3": nn.Sequential(
+                     nn.Conv2d(
+                         self.d_model,
+                         self.d_model // 4,
+                         3,
+                         padding=1,
+                         bias=False,
+                     ),
+                     nn.Upsample(
+                         scale_factor=4, mode="bilinear", align_corners=False
+                     ),
+                 ),
+                 "layer4": nn.Sequential(
+                     nn.Conv2d(
+                         self.d_model,
+                         self.d_model // 4,
+                         3,
+                         padding=1,
+                         bias=False,
+                     ),
+                     nn.Upsample(
+                         scale_factor=4, mode="bilinear", align_corners=False
+                     ),
+                 ),
+             }
+         )
+
+         self.binarize = nn.Sequential(
+             nn.Conv2d(
+                 self.d_model, self.d_model // 4, 3, padding=1, bias=False
+             ),
+             nn.BatchNorm2d(self.d_model // 4),
+             nn.ReLU(inplace=True),
+             nn.ConvTranspose2d(self.d_model // 4, self.d_model // 4, 2, 2),
+             nn.BatchNorm2d(self.d_model // 4),
+             nn.ReLU(inplace=True),
+             nn.ConvTranspose2d(self.d_model // 4, 1, 2, 2),
+             nn.Sigmoid(),
+         )
+
+         self.adaptive = adaptive
+         self.serial = serial
+         if self.adaptive:
+             self.thresh = self._init_thresh(
+                 self.d_model,
+                 serial=serial,
+                 smooth=smooth,
+                 bias=False,
+             )
+             self.thresh.apply(self.weights_init)
+
+         self.binarize.apply(self.weights_init)
+
+         for layer in self.input_proj.values():
+             layer.apply(self.weights_init)
+
+         for layer in self.out_proj.values():
+             layer.apply(self.weights_init)
+
+         self.concat_attention = ScaleFeatureSelection(
+             self.d_model,
+             self.d_model // 4,
+             attention_type="scale_channel_spatial",
+         )
+
+     def weights_init(self, m):
+         classname = m.__class__.__name__
+         if classname.find("Conv") != -1:
+             nn.init.kaiming_normal_(m.weight.data)
+         elif classname.find("BatchNorm") != -1:
+             m.weight.data.fill_(1.0)
+             m.bias.data.fill_(1e-4)
+
+     def _init_thresh(
+         self, inner_channels, serial=False, smooth=False, bias=False
+     ):
+         in_channels = inner_channels
+         if serial:
+             in_channels += 1
+         self.thresh = nn.Sequential(
+             nn.Conv2d(
+                 in_channels, inner_channels // 4, 3, padding=1, bias=bias
+             ),
+             nn.BatchNorm2d(inner_channels // 4),
+             nn.ReLU(inplace=True),
+             self._init_upsample(
+                 inner_channels // 4,
+                 inner_channels // 4,
+                 smooth=smooth,
+                 bias=bias,
+             ),
+             nn.BatchNorm2d(inner_channels // 4),
+             nn.ReLU(inplace=True),
+             self._init_upsample(
+                 inner_channels // 4, 1, smooth=smooth, bias=bias
+             ),
+             nn.Sigmoid(),
+         )
+         return self.thresh
+
+     def _init_upsample(
+         self, in_channels, out_channels, smooth=False, bias=False
+     ):
+         if smooth:
+             inter_out_channels = out_channels
+             if out_channels == 1:
+                 inter_out_channels = in_channels
+             module_list = [
+                 nn.Upsample(scale_factor=2, mode="nearest"),
+                 nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias),
+             ]
+             if out_channels == 1:
+                 module_list.append(
+                     nn.Conv2d(
+                         in_channels,
+                         out_channels,
+                         kernel_size=1,
+                         stride=1,
+                         padding=0,
+                         bias=False,
+                     )
+                 )
+
+             return nn.Sequential(*module_list)
+         else:
+             return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
+
+     def step_function(self, x, y):
+         return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
+
+     def forward(self, features):
+         for layer, feature in features.items():
+             features[layer] = self.input_proj[layer](feature)
+
+         layers = ["layer4", "layer3", "layer2", "layer1"]
+         for i in range(self.n_layers - 1):
+             feature_bottom = features[layers[i]]
+             feature_top = features[layers[i + 1]]
+
+             bh, bw = feature_bottom.shape[-2:]
+             th, tw = feature_top.shape[-2:]
+
+             if bh != th or bw != tw:
+                 feature_bottom = F.interpolate(
+                     feature_bottom,
+                     size=(th, tw),
+                     mode="bilinear",
+                     align_corners=False,
+                 )
+
+             features[layers[i + 1]] = feature_bottom + feature_top
+
+         fp = []
+         for layer, feature in features.items():
+             fp.append(self.out_proj[layer](feature))
+         fuse = torch.cat(fp[::-1], dim=1)
+         fuse = self.concat_attention(fuse, fp[::-1])
+
+         binary = self.binarize(fuse)
+         result = OrderedDict(binary=binary)
+         return result
+
+
+ class DBNet(nn.Module, PyTorchModelHubMixin):
+     def __init__(
+         self,
+         cfg,
+     ):
+         super().__init__()
+         self.cfg = cfg
+         self.backbone = Backbone(**cfg.backbone)
+         self.decoder = DBNetDecoder(**cfg.decoder)
+
+     def forward(self, tensor):
+         features = self.backbone(tensor)
+         xs = self.decoder(features)
+         return xs
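Note: DBNet here follows the usual Differentiable Binarization layout: the ResNet backbone exposes layer1-layer4 feature maps, DBNetDecoder projects each to a shared width, fuses them top-down, reweights the per-scale maps with ScaleFeatureSelection, and the binarize head upsamples back to a single-channel text-probability map at input resolution. A minimal instantiation sketch; the config values below are hypothetical stand-ins (the packaged defaults live in yomitoku/configs/cfg_text_detector_dbnet.py):

from types import SimpleNamespace

import torch

from yomitoku.models.dbnet_plus import DBNet

# Hypothetical config object; attribute names mirror the Backbone/DBNetDecoder kwargs above.
cfg = SimpleNamespace(
    backbone={"name": "resnet50", "dilation": True},
    decoder={"in_channels": [256, 512, 1024, 2048], "hidden_dim": 256},
)
model = DBNet(cfg).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 640, 640))
print(out["binary"].shape)  # single-channel probability map at input resolution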
yomitoku/models/layers/__init__.py
File without changes
yomitoku/models/layers/activate.py
@@ -0,0 +1,38 @@
+ import torch.nn as nn
+
+
+ def get_activation(act: str, inplace: bool = True):
+     """get activation"""
+     if act is None:
+         return nn.Identity()
+
+     elif isinstance(act, nn.Module):
+         return act
+
+     act = act.lower()
+
+     if act == "silu" or act == "swish":
+         m = nn.SiLU()
+
+     elif act == "relu":
+         m = nn.ReLU()
+
+     elif act == "leaky_relu":
+         m = nn.LeakyReLU()
+
+     elif act == "silu":
+         m = nn.SiLU()
+
+     elif act == "gelu":
+         m = nn.GELU()
+
+     elif act == "hardsigmoid":
+         m = nn.Hardsigmoid()
+
+     else:
+         raise RuntimeError("")
+
+     if hasattr(m, "inplace"):
+         m.inplace = inplace
+
+     return m
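For reference, get_activation is a small name-to-module activation factory; a quick sketch of how the names resolve, assuming the package is installed:

import torch.nn as nn

from yomitoku.models.layers.activate import get_activation

assert isinstance(get_activation("relu"), nn.ReLU)      # inplace=True is set when the module supports it
assert isinstance(get_activation("swish"), nn.SiLU)     # "swish" and "silu" both map to SiLU
assert isinstance(get_activation(None), nn.Identity)    # None means no activation
assert isinstance(get_activation("gelu"), nn.GELU)      # GELU has no inplace flag, left as-is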
yomitoku/models/layers/dbnet_feature_attention.py
@@ -0,0 +1,160 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class ScaleChannelAttention(nn.Module):
+     def __init__(self, in_planes, out_planes, num_features, init_weight=True):
+         super(ScaleChannelAttention, self).__init__()
+         self.avgpool = nn.AdaptiveAvgPool2d(1)
+         print(self.avgpool)
+         self.fc1 = nn.Conv2d(in_planes, out_planes, 1, bias=False)
+         self.bn = nn.BatchNorm2d(out_planes)
+         self.fc2 = nn.Conv2d(out_planes, num_features, 1, bias=False)
+         if init_weight:
+             self._initialize_weights()
+
+     def _initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             if isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         global_x = self.avgpool(x)
+         global_x = self.fc1(global_x)
+         global_x = F.relu(self.bn(global_x))
+         global_x = self.fc2(global_x)
+         global_x = F.softmax(global_x, 1)
+         return global_x
+
+
+ class ScaleChannelSpatialAttention(nn.Module):
+     def __init__(self, in_planes, out_planes, num_features, init_weight=True):
+         super(ScaleChannelSpatialAttention, self).__init__()
+         self.channel_wise = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_planes, out_planes, 1, bias=False),
+             # nn.BatchNorm2d(out_planes),
+             nn.ReLU(),
+             nn.Conv2d(out_planes, in_planes, 1, bias=False),
+         )
+         self.spatial_wise = nn.Sequential(
+             # Nx1xHxW
+             nn.Conv2d(1, 1, 3, bias=False, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(1, 1, 1, bias=False),
+             nn.Sigmoid(),
+         )
+         self.attention_wise = nn.Sequential(
+             nn.Conv2d(in_planes, num_features, 1, bias=False), nn.Sigmoid()
+         )
+         if init_weight:
+             self._initialize_weights()
+
+     def _initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             if isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         # global_x = self.avgpool(x)
+         # shape Nx4x1x1
+         global_x = self.channel_wise(x).sigmoid()
+         # shape: NxCxHxW
+         global_x = global_x + x
+         # shape:Nx1xHxW
+         x = torch.mean(global_x, dim=1, keepdim=True)
+         global_x = self.spatial_wise(x) + global_x
+         global_x = self.attention_wise(global_x)
+         return global_x
+
+
+ class ScaleSpatialAttention(nn.Module):
+     def __init__(self, in_planes, out_planes, num_features, init_weight=True):
+         super(ScaleSpatialAttention, self).__init__()
+         self.spatial_wise = nn.Sequential(
+             # Nx1xHxW
+             nn.Conv2d(1, 1, 3, bias=False, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(1, 1, 1, bias=False),
+             nn.Sigmoid(),
+         )
+         self.attention_wise = nn.Sequential(
+             nn.Conv2d(in_planes, num_features, 1, bias=False), nn.Sigmoid()
+         )
+         if init_weight:
+             self._initialize_weights()
+
+     def _initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             if isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         global_x = torch.mean(x, dim=1, keepdim=True)
+         global_x = self.spatial_wise(global_x) + x
+         global_x = self.attention_wise(global_x)
+         return global_x
+
+
+ class ScaleFeatureSelection(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         inter_channels,
+         out_features_num=4,
+         attention_type="scale_spatial",
+     ):
+         super(ScaleFeatureSelection, self).__init__()
+         self.in_channels = in_channels
+         self.inter_channels = inter_channels
+         self.out_features_num = out_features_num
+         self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)
+         self.type = attention_type
+         if self.type == "scale_spatial":
+             self.enhanced_attention = ScaleSpatialAttention(
+                 inter_channels, inter_channels // 4, out_features_num
+             )
+         elif self.type == "scale_channel_spatial":
+             self.enhanced_attention = ScaleChannelSpatialAttention(
+                 inter_channels, inter_channels // 4, out_features_num
+             )
+         elif self.type == "scale_channel":
+             self.enhanced_attention = ScaleChannelAttention(
+                 inter_channels, inter_channels // 2, out_features_num
+             )
+
+     def _initialize_weights(self, m):
+         classname = m.__class__.__name__
+         if classname.find("Conv") != -1:
+             nn.init.kaiming_normal_(m.weight.data)
+         elif classname.find("BatchNorm") != -1:
+             m.weight.data.fill_(1.0)
+             m.bias.data.fill_(1e-4)
+
+     def forward(self, concat_x, features_list):
+         concat_x = self.conv(concat_x)
+         score = self.enhanced_attention(concat_x)
+         assert len(features_list) == self.out_features_num
+         if self.type not in ["scale_channel_spatial", "scale_spatial"]:
+             shape = features_list[0].shape[2:]
+             score = F.interpolate(score, size=shape, mode="bilinear")
+         x = []
+         for i in range(self.out_features_num):
+             x.append(score[:, i : i + 1] * features_list[i])
+         return torch.cat(x, dim=1)
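Note: ScaleFeatureSelection is the module DBNetDecoder wires in as concat_attention above; it scores the concatenated multi-scale features and reweights each of the four per-scale maps before the final fusion. A minimal shape check, assuming hidden_dim=256 so each projected map carries 64 channels:

import torch

from yomitoku.models.layers.dbnet_feature_attention import ScaleFeatureSelection

attn = ScaleFeatureSelection(256, 64, attention_type="scale_channel_spatial")
feats = [torch.randn(1, 64, 160, 160) for _ in range(4)]  # four per-scale maps, already resized
fused = attn(torch.cat(feats, dim=1), feats)              # per-scale attention, then concat
print(fused.shape)  # torch.Size([1, 256, 160, 160])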
yomitoku/models/layers/parseq_transformer.py
@@ -0,0 +1,218 @@
+ # Scene Text Recognition Model Hub
+ # Copyright 2022 Darwin Bautista
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Optional
+
+ import torch
+ from timm.models.vision_transformer import PatchEmbed, VisionTransformer
+ from torch import Tensor
+ from torch import nn as nn
+ from torch.nn import functional as F
+ from torch.nn.modules import transformer
+
+
+ class DecoderLayer(nn.Module):
+     """A Transformer decoder layer supporting two-stream attention (XLNet)
+     This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.
+     """
+
+     def __init__(
+         self,
+         embed_dim,
+         num_heads,
+         mlp_ratio=2048,
+         dropout=0.1,
+         activation="gelu",
+         layer_norm_eps=1e-5,
+     ):
+         super().__init__()
+         dim_feedforward = embed_dim * mlp_ratio
+         self.self_attn = nn.MultiheadAttention(
+             embed_dim, num_heads, dropout=dropout, batch_first=True
+         )
+         self.cross_attn = nn.MultiheadAttention(
+             embed_dim, num_heads, dropout=dropout, batch_first=True
+         )
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(embed_dim, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, embed_dim)
+
+         self.norm1 = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
+         self.norm2 = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
+         self.norm_q = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
+         self.norm_c = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         self.activation = transformer._get_activation_fn(activation)
+
+     def __setstate__(self, state):
+         if "activation" not in state:
+             state["activation"] = F.gelu
+         super().__setstate__(state)
+
+     def forward_stream(
+         self,
+         tgt: Tensor,
+         tgt_norm: Tensor,
+         tgt_kv: Tensor,
+         memory: Tensor,
+         tgt_mask: Optional[Tensor],
+         tgt_key_padding_mask: Optional[Tensor],
+     ):
+         """Forward pass for a single stream (i.e. content or query)
+         tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
+         Both tgt_kv and memory are expected to be LayerNorm'd too.
+         memory is LayerNorm'd by ViT.
+         """
+         tgt2, sa_weights = self.self_attn(
+             tgt_norm,
+             tgt_kv,
+             tgt_kv,
+             attn_mask=tgt_mask,
+             key_padding_mask=tgt_key_padding_mask,
+         )
+         tgt = tgt + self.dropout1(tgt2)
+
+         tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
+         tgt = tgt + self.dropout2(tgt2)
+
+         tgt2 = self.linear2(
+             self.dropout(self.activation(self.linear1(self.norm2(tgt))))
+         )
+         tgt = tgt + self.dropout3(tgt2)
+         return tgt, sa_weights, ca_weights
+
+     def forward(
+         self,
+         query,
+         content,
+         memory,
+         query_mask: Optional[Tensor] = None,
+         content_mask: Optional[Tensor] = None,
+         content_key_padding_mask: Optional[Tensor] = None,
+         update_content: bool = True,
+     ):
+         query_norm = self.norm_q(query)
+         content_norm = self.norm_c(content)
+         query = self.forward_stream(
+             query,
+             query_norm,
+             content_norm,
+             memory,
+             query_mask,
+             content_key_padding_mask,
+         )[0]
+         if update_content:
+             content = self.forward_stream(
+                 content,
+                 content_norm,
+                 content_norm,
+                 memory,
+                 content_mask,
+                 content_key_padding_mask,
+             )[0]
+         return query, content
+
+
+ class Decoder(nn.Module):
+     __constants__ = ["norm"]
+
+     def __init__(self, norm, cfg):
+         super().__init__()
+         decoder_layer = DecoderLayer(
+             embed_dim=cfg.embed_dim,
+             num_heads=cfg.num_heads,
+             mlp_ratio=cfg.mlp_ratio,
+         )
+
+         self.layers = transformer._get_clones(decoder_layer, cfg.depth)
+         self.num_layers = cfg.depth
+         self.norm = norm
+
+     def forward(
+         self,
+         query,
+         content,
+         memory,
+         query_mask: Optional[Tensor] = None,
+         content_mask: Optional[Tensor] = None,
+         content_key_padding_mask: Optional[Tensor] = None,
+     ):
+         for i, mod in enumerate(self.layers):
+             last = i == len(self.layers) - 1
+             query, content = mod(
+                 query,
+                 content,
+                 memory,
+                 query_mask,
+                 content_mask,
+                 content_key_padding_mask,
+                 update_content=not last,
+             )
+         query = self.norm(query)
+         return query
+
+
+ class Encoder(VisionTransformer):
+     def __init__(
+         self,
+         img_size=224,
+         patch_size=16,
+         in_chans=3,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         drop_rate=0.0,
+         attn_drop_rate=0.0,
+         drop_path_rate=0.0,
+         embed_layer=PatchEmbed,
+     ):
+         super().__init__(
+             img_size,
+             patch_size,
+             in_chans,
+             embed_dim=embed_dim,
+             depth=depth,
+             num_heads=num_heads,
+             mlp_ratio=mlp_ratio,
+             qkv_bias=qkv_bias,
+             drop_rate=drop_rate,
+             attn_drop_rate=attn_drop_rate,
+             drop_path_rate=drop_path_rate,
+             embed_layer=embed_layer,
+             num_classes=0,  # These
+             global_pool="",  # disable the
+             class_token=False,  # classifier head.
+         )
+
+     def forward(self, x):
+         # Return all tokens
+         return self.forward_features(x)
+
+
+ class TokenEmbedding(nn.Module):
+     def __init__(self, charset_size: int, embed_dim: int):
+         super().__init__()
+         self.embedding = nn.Embedding(charset_size, embed_dim)
+         self.embed_dim = embed_dim
+
+     def forward(self, tokens: torch.Tensor):
+         return math.sqrt(self.embed_dim) * self.embedding(tokens)
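Note: these are the PARSeq building blocks: Encoder is a ViT that returns all patch tokens, DecoderLayer/Decoder implement the two-stream (query/content) attention, and TokenEmbedding scales character embeddings by sqrt(embed_dim). A small sketch of the decoder plumbing with hypothetical dimensions (the packaged defaults live in yomitoku/configs/cfg_text_recognizer_parseq.py):

from types import SimpleNamespace

import torch
from torch import nn

from yomitoku.models.layers.parseq_transformer import Decoder, TokenEmbedding

cfg = SimpleNamespace(embed_dim=384, num_heads=6, mlp_ratio=4, depth=1)  # hypothetical values
decoder = Decoder(nn.LayerNorm(cfg.embed_dim), cfg)
embed = TokenEmbedding(charset_size=100, embed_dim=cfg.embed_dim)

memory = torch.randn(1, 128, cfg.embed_dim)      # patch tokens from the ViT Encoder
content = embed(torch.randint(0, 100, (1, 8)))   # content stream: character embeddings
query = torch.zeros(1, 8, cfg.embed_dim)         # query (position) stream
out = decoder(query, content, memory)
print(out.shape)  # torch.Size([1, 8, 384])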