dsipts-1.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of dsipts might be problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
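
The listing above enumerates every file shipped in the wheel. For orientation, here is a minimal sketch of how this exact release could be installed and its recorded contents inspected; the distribution name and pinned version come from the listing, everything else is illustrative:

# Install the exact release shown in this diff (assumes the package is available
# from your configured registry):
#     pip install dsipts==1.1.5
#
# Then inspect the installed distribution with the standard library:
from importlib.metadata import version, files

print(version("dsipts"))              # "1.1.5"
for entry in (files("dsipts") or [])[:5]:
    print(entry)                      # first few paths from the RECORD file listed above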
dsipts/models/d3vae/neural_operations.py
@@ -0,0 +1,314 @@
+# -*-Encoding: utf-8 -*-
+"""
+Authors:
+    Li,Yan (liyan22021121@gmail.com)
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.batchnorm import _BatchNorm
+from collections import OrderedDict
+
+
+BN_EPS = 1e-5
+SYNC_BN = False
+
+OPS = OrderedDict([
+    ('res_elu', lambda Cin, Cout, stride: ELUConv(Cin, Cout, 3, stride, 1)),
+    ('res_bnelu', lambda Cin, Cout, stride: BNELUConv(Cin, Cout, 3, stride, 1)),
+    ('res_bnswish', lambda Cin, Cout, stride: BNSwishConv(Cin, Cout, 3, stride, 1)),
+    ('res_bnswish5', lambda Cin, Cout, stride: BNSwishConv(Cin, Cout, 3, stride, 2, 2)),
+    ('mconv_e6k5g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=6, dil=1, k=5, g=1)),
+    ('mconv_e3k5g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=3, dil=1, k=5, g=1)),
+    ('mconv_e3k5g8', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=3, dil=1, k=5, g=8)),
+    ('mconv_e6k11g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=6, dil=1, k=11, g=0)),
+])
+
+
+class SyncBatchNormSwish(_BatchNorm):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+                 track_running_stats=True, process_group=None):
+        super(SyncBatchNormSwish, self).__init__(num_features, eps, momentum, affine, track_running_stats)
+        self.process_group = process_group
+        self.ddp_gpu_size = None
+
+    def forward(self, input):
+        exponential_average_factor = self.momentum
+        out = F.batch_norm(
+            input, self.running_mean, self.running_var, self.weight, self.bias,
+            self.training or not self.track_running_stats,
+            exponential_average_factor, self.eps)
+        return out
+
+
+def get_skip_connection(C, stride, channel_mult):
+    if stride == 1:
+        return Identity()
+    elif stride == 2:
+        return FactorizedReduce(C, int(channel_mult * C))
+    elif stride == -1:
+        return nn.Sequential(UpSample(), Conv2D(C, int(C / channel_mult), kernel_size=1))
+
+
+def norm(t, dim):
+    return torch.sqrt(torch.sum(t * t, dim))
+
+
+def logit(t):
+    return torch.log(t) - torch.log(1 - t)
+
+
+def act(t):
+    # The following implementation has lower memory.
+    return SwishFN.apply(t)
+
+
+class SwishFN(torch.autograd.Function):
+    def forward(ctx, i):
+        result = i * torch.sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    def backward(ctx, grad_output):
+        i = ctx.saved_variables[0]
+        sigmoid_i = torch.sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class Swish(nn.Module):
+    def __init__(self):
+        super(Swish, self).__init__()
+
+    def forward(self, x):
+        return act(x)
+
+
+def normalize_weight_jit(log_weight_norm, weight):
+    n = torch.exp(log_weight_norm)
+    wn = torch.sqrt(torch.sum(weight * weight, dim=[1, 2, 3]))  # norm(w)
+    weight = n * weight / (wn.view(-1, 1, 1, 1) + 1e-5)
+    return weight
+
+
+class Conv2D(nn.Conv2d):
+    """Allows for weights as input."""
+
+    def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, data_init=False,
+                 weight_norm=True):
+        """
+        Args:
+            use_shared (bool): Use weights for this layer or not?
+        """
+        super(Conv2D, self).__init__(C_in, C_out, kernel_size, stride, padding, dilation, groups, bias)
+
+        self.log_weight_norm = None
+        if weight_norm:
+            init = norm(self.weight, dim=[1, 2, 3]).view(-1, 1, 1, 1)
+            self.log_weight_norm = nn.Parameter(torch.log(init + 1e-2), requires_grad=True)
+
+        self.data_init = data_init
+        self.init_done = False
+        self.weight_normalized = self.normalize_weight()
+
+    def forward(self, x):
+        # do data based initialization
+        self.weight_normalized = self.normalize_weight()
+        # print(self.weight_normalized.shape)
+        bias = self.bias
+        return F.conv2d(x, self.weight_normalized, bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+    def normalize_weight(self):
+        """ applies weight normalization """
+        if self.log_weight_norm is not None:
+            weight = normalize_weight_jit(self.log_weight_norm, self.weight)
+        else:
+            weight = self.weight
+
+        return weight
+
+
+class Identity(nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+
+class SyncBatchNorm(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super(SyncBatchNorm, self).__init__()
+        self.bn = nn.BatchNorm(*args, **kwargs)
+
+    def forward(self, x):
+        return self.bn(x)
+
+
+# quick switch between multi-gpu, single-gpu batch norm
+def get_batchnorm(*args, **kwargs):
+    return nn.BatchNorm2d(*args, **kwargs)
+
+
+class ELUConv(nn.Module):
+    def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
+        super(ELUConv, self).__init__()
+        self.upsample = stride == -1
+        stride = abs(stride)
+        self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation,
+                             data_init=True)
+
+    def forward(self, x):
+        out = F.elu(x)
+        if self.upsample:
+            out = F.interpolate(out, scale_factor=2, mode='nearest')
+        out = self.conv_0(out)
+        return out
+
+
+class BNELUConv(nn.Module):
+    def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
+        super(BNELUConv, self).__init__()
+        self.upsample = stride == -1
+        stride = abs(stride)
+        self.bn = get_batchnorm(C_in, eps=BN_EPS, momentum=0.05)
+        self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation)
+
+    def forward(self, x):
+        x = self.bn(x)
+        out = F.elu(x)
+        if self.upsample:
+            out = F.interpolate(out, scale_factor=2, mode='nearest')
+        out = self.conv_0(out)
+        return out
+
+
+class BNSwishConv(nn.Module):
+    """ReLU + Conv2d + BN."""
+
+    def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
+        super(BNSwishConv, self).__init__()
+        self.upsample = stride == -1
+        stride = abs(stride)
+        self.bn_act = SyncBatchNormSwish(C_in, eps=BN_EPS, momentum=0.05)
+        self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation)
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): of size (B, C_in, H, W)
+        """
+        out = self.bn_act(x)
+        if self.upsample:
+            out = F.interpolate(out, scale_factor=2, mode='nearest')
+        out = self.conv_0(out)
+        return out
+
+
+class FactorizedReduce(nn.Module):
+    def __init__(self, C_in, C_out):
+        super(FactorizedReduce, self).__init__()
+        assert C_out % 2 == 0
+        self.conv_1 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
+        self.conv_2 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
+        self.conv_3 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
+        self.conv_4 = Conv2D(C_in, C_out - 3 * (C_out // 4), 1, stride=2, padding=0, bias=True)
+
+    def forward(self, x):
+        out = act(x)
+        conv1 = self.conv_1(out[:, :, :, :])
+        conv2 = self.conv_2(out[:, :, 1:, :])
+        conv3 = self.conv_3(out[:, :, :, :])
+        conv4 = self.conv_4(out[:, :, 1:, :])
+        out = torch.cat([conv1, conv2, conv3, conv4], dim=1)
+        return out
+
+
+class UpSample(nn.Module):
+    def __init__(self):
+        super(UpSample, self).__init__()
+        pass
+
+    def forward(self, x):
+        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
+
+
+class EncCombinerCell(nn.Module):
+    def __init__(self, Cin1, Cin2, Cout, cell_type):
+        super(EncCombinerCell, self).__init__()
+        self.cell_type = cell_type
+        # Cin = Cin1 + Cin2
+        self.conv = Conv2D(Cin2, Cout, kernel_size=1, stride=1, padding=0, bias=True)
+
+    def forward(self, x1, x2):
+        x2 = self.conv(x2)
+        out = x1 + x2
+        return out
+
+
+# original combiner
+class DecCombinerCell(nn.Module):
+    def __init__(self, Cin1, Cin2, Cout, cell_type):
+        super(DecCombinerCell, self).__init__()
+        self.cell_type = cell_type
+        self.conv = Conv2D(Cin1 + Cin2, Cout, kernel_size=1, stride=1, padding=0, bias=True)
+
+    def forward(self, x1, x2):
+        out = torch.cat([x1, x2], dim=1)
+        out = self.conv(out)
+        return out
+
+
+class ConvBNSwish(nn.Module):
+    def __init__(self, Cin, Cout, k=3, stride=1, groups=1, dilation=1):
+        padding = dilation * (k - 1) // 2
+        super(ConvBNSwish, self).__init__()
+
+        self.conv = nn.Sequential(
+            Conv2D(Cin, Cout, k, stride, padding, groups=groups, bias=False, dilation=dilation, weight_norm=False),
+            SyncBatchNormSwish(Cout, eps=BN_EPS, momentum=0.05)  # drop in replacement for BN + Swish
+        )
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class SE(nn.Module):
+    def __init__(self, Cin, Cout):
+        super(SE, self).__init__()
+        num_hidden = max(Cout // 16, 4)
+        self.se = nn.Sequential(nn.Linear(Cin, num_hidden), nn.ReLU(inplace=True),
+                                nn.Linear(num_hidden, Cout), nn.Sigmoid())
+
+    def forward(self, x):
+        se = torch.mean(x, dim=[2, 3])
+        se = se.view(se.size(0), -1)
+        se = self.se(se)
+        se = se.view(se.size(0), -1, 1, 1)
+        return x * se
+
+
+class InvertedResidual(nn.Module):
+    def __init__(self, Cin, Cout, stride, ex, dil, k, g):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2, -1]
+
+        hidden_dim = int(round(Cin * ex))
+        self.use_res_connect = self.stride == 1 and Cin == Cout
+        self.upsample = self.stride == -1
+        self.stride = abs(self.stride)
+        groups = hidden_dim if g == 0 else g
+
+        layers0 = [nn.UpsamplingNearest2d(scale_factor=2)] if self.upsample else []
+        layers = [get_batchnorm(Cin, eps=BN_EPS, momentum=0.05),
+                  ConvBNSwish(Cin, hidden_dim, k=1),
+                  ConvBNSwish(hidden_dim, hidden_dim, stride=self.stride, groups=groups, k=k, dilation=dil),
+                  Conv2D(hidden_dim, Cout, 1, 1, 0, bias=False, weight_norm=False),
+                  get_batchnorm(Cout, momentum=0.05)]
+
+        layers0.extend(layers)
+        self.conv = nn.Sequential(*layers0)
+
+    def forward(self, x):
+        return self.conv(x)
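
The hunk above (neural_operations.py) registers the D3VAE building blocks in the OPS dictionary, so cells are constructed by name rather than by class. As a minimal smoke-test sketch, assuming the wheel is installed so the module path from the listing is importable; the tensor sizes and the chosen op name are purely illustrative:

# Build one registered cell and the matching skip connection, then run a dummy batch.
import torch
from dsipts.models.d3vae.neural_operations import OPS, get_skip_connection

x = torch.randn(2, 16, 32, 32)                        # (B, C, H, W)
cell = OPS['mconv_e6k5g0'](16, 16, 1)                 # InvertedResidual, stride 1
skip = get_skip_connection(16, 1, channel_mult=1.0)   # Identity() for stride 1

out = cell(x) + skip(x)                               # residual-style combination
print(out.shape)                                      # torch.Size([2, 16, 32, 32])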
dsipts/models/d3vae/resnet.py
@@ -0,0 +1,153 @@
+# -*-Encoding: utf-8 -*-
+"""
+Authors:
+    Li,Yan (liyan22021121@gmail.com)
+"""
+from torch import nn
+
+
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(0.0, 0.2)
+    elif classname.find("BatchNorm") != -1:
+        m.weight.data.normal_(1.0, 0.2)
+        m.bias.data.fill_(0)
+
+
+class MyConvo2d(nn.Module):
+    def __init__(self, input_dim, output_dim, kernel_size, stride = 1, bias = True):
+        super(MyConvo2d, self).__init__()
+        self.padding = int((kernel_size - 1)/2)
+        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=self.padding, bias = bias)
+
+    def forward(self, input):
+        output = self.conv(input)
+        return output
+
+
+class Square(nn.Module):
+    def __init__(self):
+        super(Square,self).__init__()
+        pass
+
+    def forward(self,in_vect):
+        return in_vect**2
+
+
+class Swish(nn.Module):
+    def __init__(self):
+        super(Swish,self).__init__()
+        pass
+
+    def forward(self,in_vect):
+        return in_vect*nn.functional.sigmoid(in_vect)
+
+
+class MeanPoolConv(nn.Module):
+    def __init__(self, input_dim, output_dim, kernel_size):
+        super(MeanPoolConv, self).__init__()
+        self.conv = MyConvo2d(input_dim, output_dim, kernel_size)
+
+    def forward(self, input):
+        output = input
+        output = self.conv(output)
+        return output
+
+
+class ConvMeanPool(nn.Module):
+    def __init__(self, input_dim, output_dim, kernel_size):
+        super(ConvMeanPool, self).__init__()
+        self.conv = MyConvo2d(input_dim, output_dim, kernel_size)
+
+    def forward(self, input):
+        output = self.conv(input)
+        return output
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, input_dim, output_dim, kernel_size, hw, resample=None, normalize=False,AF=nn.ELU()):
+        super(ResidualBlock, self).__init__()
+
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.kernel_size = kernel_size
+        self.resample = resample
+        self.normalize = normalize
+        self.bn1 = None
+        self.bn2 = None
+        self.relu1 = AF
+        self.relu2 = AF
+        if resample == 'down':
+            self.bn1 = nn.LayerNorm([input_dim, hw, hw])
+            self.bn2 = nn.LayerNorm([input_dim, hw, hw])
+        elif resample == 'none':
+            self.bn1 = nn.LayerNorm([input_dim, hw, hw])
+            self.bn2 = nn.LayerNorm([input_dim, hw, hw])
+
+        if resample == 'down':
+            self.conv_shortcut = MeanPoolConv(input_dim, output_dim, kernel_size = 1)
+            self.conv_1 = MyConvo2d(input_dim, input_dim, kernel_size = kernel_size, bias = False)
+            self.conv_2 = ConvMeanPool(input_dim, output_dim, kernel_size = kernel_size)
+        elif resample == 'none':
+            self.conv_shortcut = MyConvo2d(input_dim, output_dim, kernel_size = 1)
+            self.conv_1 = MyConvo2d(input_dim, input_dim, kernel_size = kernel_size, bias = False)
+            self.conv_2 = MyConvo2d(input_dim, output_dim, kernel_size = kernel_size)
+
+    def forward(self, input):
+        if self.input_dim == self.output_dim and self.resample is None:
+            shortcut = input
+        else:
+            shortcut = self.conv_shortcut(input)
+
+        if self.normalize is False:
+            output = input
+            output = self.relu1(output)
+            output = self.conv_1(output)
+            output = self.relu2(output)
+            output = self.conv_2(output)
+        else:
+            output = input
+            output = self.bn1(output)
+            output = self.relu1(output)
+            output = self.conv_1(output)
+            output = self.bn2(output)
+            output = self.relu2(output)
+            output = self.conv_2(output)
+
+        return shortcut + output
+
+
+class Res12_Quadratic(nn.Module):
+    def __init__(self,inchan,dim,hw,normalize=False,AF=None):
+        super(Res12_Quadratic, self).__init__()
+
+        self.hw = hw
+        self.dim = dim
+        self.inchan = inchan
+        self.conv1 = MyConvo2d(inchan,dim, 3)
+        self.rb1 = ResidualBlock(dim, 2*dim, 3, int(hw), resample = 'down',normalize=normalize,AF=AF)
+        self.rbc1 = ResidualBlock(2*dim, 2*dim, 3, int(hw/2), resample = 'none',normalize=normalize,AF=AF)
+        self.rb2 = ResidualBlock(2*dim, 4*dim, 3, int(hw/2), resample = 'down',normalize=normalize,AF=AF)
+        self.rbc2 = ResidualBlock(4*dim, 4*dim, 3, int(hw/4), resample = 'none',normalize=normalize,AF=AF)
+        self.rb3 = ResidualBlock(4*dim, 8*dim, 3, int(hw/4), resample = 'down',normalize=normalize,AF=AF)
+        self.rbc3 = ResidualBlock(8*dim, 8*dim, 3, int(hw/8), resample = 'none',normalize=normalize,AF=AF)
+        self.ln1 = nn.Linear(int(hw/8)*int(hw/8)*8*dim, 1)
+        self.ln2 = nn.Linear(int(hw/8)*int(hw/8)*8*dim, 1)
+        self.lq = nn.Linear(int(hw/8)*int(hw/8)*8*dim, 1)
+        self.Square = Square()
+
+    def forward(self, x_in):
+        output = x_in
+        output = self.conv1(output)
+        # print(output.shape)
+        output = self.rb1(output)
+        output = self.rbc1(output)
+        output = self.rb2(output)
+        output = self.rbc2(output)
+        output = self.rb3(output)
+        output = self.rbc3(output)
+        output = output.view(-1, int(self.hw/8)*int(self.hw/8)*8*self.dim)
+        output = self.ln1(output)*self.ln2(output)+self.lq(self.Square(output))
+        output = output.view(-1)
+        return output
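
This second hunk (resnet.py) mainly provides Res12_Quadratic, a residual convolutional network with a quadratic output head (ln1·ln2 + lq(x²)). A minimal instantiation sketch, again assuming the installed module path from the listing; the channel and resolution values are illustrative, and since AF defaults to None an activation module must be passed explicitly:

# Instantiate the quadratic-head residual network and run a dummy image batch.
# hw should be divisible by 8 so the final view()/Linear bookkeeping lines up.
import torch
from torch import nn
from dsipts.models.d3vae.resnet import Res12_Quadratic

net = Res12_Quadratic(inchan=3, dim=8, hw=32, normalize=False, AF=nn.ELU())

x = torch.randn(4, 3, 32, 32)   # (B, inchan, hw, hw)
scores = net(x)                 # flattened output of the quadratic head
print(scores.shape)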