dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/d3vae/neural_operations.py
@@ -0,0 +1,314 @@
+ # -*-Encoding: utf-8 -*-
+ """
+ Authors:
+     Li,Yan (liyan22021121@gmail.com)
+ """
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.modules.batchnorm import _BatchNorm
+ from collections import OrderedDict
+
+
+ BN_EPS = 1e-5
+ SYNC_BN = False
+
+ OPS = OrderedDict([
+     ('res_elu', lambda Cin, Cout, stride: ELUConv(Cin, Cout, 3, stride, 1)),
+     ('res_bnelu', lambda Cin, Cout, stride: BNELUConv(Cin, Cout, 3, stride, 1)),
+     ('res_bnswish', lambda Cin, Cout, stride: BNSwishConv(Cin, Cout, 3, stride, 1)),
+     ('res_bnswish5', lambda Cin, Cout, stride: BNSwishConv(Cin, Cout, 3, stride, 2, 2)),
+     ('mconv_e6k5g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=6, dil=1, k=5, g=1)),
+     ('mconv_e3k5g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=3, dil=1, k=5, g=1)),
+     ('mconv_e3k5g8', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=3, dil=1, k=5, g=8)),
+     ('mconv_e6k11g0', lambda Cin, Cout, stride: InvertedResidual(Cin, Cout, stride, ex=6, dil=1, k=11, g=0)),
+ ])
+
+
+ class SyncBatchNormSwish(_BatchNorm):
+     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+                  track_running_stats=True, process_group=None):
+         super(SyncBatchNormSwish, self).__init__(num_features, eps, momentum, affine, track_running_stats)
+         self.process_group = process_group
+         self.ddp_gpu_size = None
+
+     def forward(self, input):
+         exponential_average_factor = self.momentum
+         out = F.batch_norm(
+             input, self.running_mean, self.running_var, self.weight, self.bias,
+             self.training or not self.track_running_stats,
+             exponential_average_factor, self.eps)
+         return out
+
+
+ def get_skip_connection(C, stride, channel_mult):
+     if stride == 1:
+         return Identity()
+     elif stride == 2:
+         return FactorizedReduce(C, int(channel_mult * C))
+     elif stride == -1:
+         return nn.Sequential(UpSample(), Conv2D(C, int(C / channel_mult), kernel_size=1))
+
+
+ def norm(t, dim):
+     return torch.sqrt(torch.sum(t * t, dim))
+
+
+ def logit(t):
+     return torch.log(t) - torch.log(1 - t)
+
+
+ def act(t):
+     # The following implementation has lower memory.
+     return SwishFN.apply(t)
+
+
+ class SwishFN(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, i):
+         result = i * torch.sigmoid(i)
+         ctx.save_for_backward(i)
+         return result
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         i = ctx.saved_tensors[0]
+         sigmoid_i = torch.sigmoid(i)
+         return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+ class Swish(nn.Module):
+     def __init__(self):
+         super(Swish, self).__init__()
+
+     def forward(self, x):
+         return act(x)
+
+
+ def normalize_weight_jit(log_weight_norm, weight):
+     n = torch.exp(log_weight_norm)
+     wn = torch.sqrt(torch.sum(weight * weight, dim=[1, 2, 3]))  # norm(w)
+     weight = n * weight / (wn.view(-1, 1, 1, 1) + 1e-5)
+     return weight
+
+
+ class Conv2D(nn.Conv2d):
+     """Allows for weights as input."""
+
+     def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, data_init=False,
+                  weight_norm=True):
+         """
+         Args:
+             use_shared (bool): Use weights for this layer or not?
+         """
+         super(Conv2D, self).__init__(C_in, C_out, kernel_size, stride, padding, dilation, groups, bias)
+
+         self.log_weight_norm = None
+         if weight_norm:
+             init = norm(self.weight, dim=[1, 2, 3]).view(-1, 1, 1, 1)
+             self.log_weight_norm = nn.Parameter(torch.log(init + 1e-2), requires_grad=True)
+
+         self.data_init = data_init
+         self.init_done = False
+         self.weight_normalized = self.normalize_weight()
+
+     def forward(self, x):
+         # do data based initialization
+         self.weight_normalized = self.normalize_weight()
+         #print(self.weight_normalized.shape)
+         bias = self.bias
+         return F.conv2d(x, self.weight_normalized, bias, self.stride,
+                         self.padding, self.dilation, self.groups)
+
+     def normalize_weight(self):
+         """ applies weight normalization """
+         if self.log_weight_norm is not None:
+             weight = normalize_weight_jit(self.log_weight_norm, self.weight)
+         else:
+             weight = self.weight
+
+         return weight
+
+
+ class Identity(nn.Module):
+     def __init__(self):
+         super(Identity, self).__init__()
+
+     def forward(self, x):
+         return x
+
+
+ class SyncBatchNorm(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super(SyncBatchNorm, self).__init__()
+         self.bn = nn.BatchNorm2d(*args, **kwargs)
+
+     def forward(self, x):
+         return self.bn(x)
+
+
+ # quick switch between multi-gpu, single-gpu batch norm
+ def get_batchnorm(*args, **kwargs):
+     return nn.BatchNorm2d(*args, **kwargs)
+
+
+ class ELUConv(nn.Module):
+     def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
+         super(ELUConv, self).__init__()
+         self.upsample = stride == -1
+         stride = abs(stride)
+         self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation,
+                              data_init=True)
+
+     def forward(self, x):
+         out = F.elu(x)
+         if self.upsample:
+             out = F.interpolate(out, scale_factor=2, mode='nearest')
+         out = self.conv_0(out)
+         return out
+
+
+ class BNELUConv(nn.Module):
+     def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
+         super(BNELUConv, self).__init__()
+         self.upsample = stride == -1
+         stride = abs(stride)
+         self.bn = get_batchnorm(C_in, eps=BN_EPS, momentum=0.05)
+         self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation)
+
+     def forward(self, x):
+         x = self.bn(x)
+         out = F.elu(x)
+         if self.upsample:
+             out = F.interpolate(out, scale_factor=2, mode='nearest')
+         out = self.conv_0(out)
+         return out
+
+
+ class BNSwishConv(nn.Module):
+     """BN + Swish + Conv2d."""
+
+     def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, dilation=1):
+         super(BNSwishConv, self).__init__()
+         self.upsample = stride == -1
+         stride = abs(stride)
+         self.bn_act = SyncBatchNormSwish(C_in, eps=BN_EPS, momentum=0.05)
+         self.conv_0 = Conv2D(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=True, dilation=dilation)
+
+     def forward(self, x):
+         """
+         Args:
+             x (torch.Tensor): of size (B, C_in, H, W)
+         """
+         out = self.bn_act(x)
+         if self.upsample:
+             out = F.interpolate(out, scale_factor=2, mode='nearest')
+         out = self.conv_0(out)
+         return out
+
+
+ class FactorizedReduce(nn.Module):
+     def __init__(self, C_in, C_out):
+         super(FactorizedReduce, self).__init__()
+         assert C_out % 2 == 0
+         self.conv_1 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
+         self.conv_2 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
+         self.conv_3 = Conv2D(C_in, C_out // 4, 1, stride=2, padding=0, bias=True)
+         self.conv_4 = Conv2D(C_in, C_out - 3 * (C_out // 4), 1, stride=2, padding=0, bias=True)
+
+     def forward(self, x):
+         out = act(x)
+         conv1 = self.conv_1(out[:, :, :, :])
+         conv2 = self.conv_2(out[:, :, 1:, :])
+         conv3 = self.conv_3(out[:, :, :, :])
+         conv4 = self.conv_4(out[:, :, 1:, :])
+         out = torch.cat([conv1, conv2, conv3, conv4], dim=1)
+         return out
+
+
+ class UpSample(nn.Module):
+     def __init__(self):
+         super(UpSample, self).__init__()
+         pass
+
+     def forward(self, x):
+         return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
+
+
+ class EncCombinerCell(nn.Module):
+     def __init__(self, Cin1, Cin2, Cout, cell_type):
+         super(EncCombinerCell, self).__init__()
+         self.cell_type = cell_type
+         # Cin = Cin1 + Cin2
+         self.conv = Conv2D(Cin2, Cout, kernel_size=1, stride=1, padding=0, bias=True)
+
+     def forward(self, x1, x2):
+         x2 = self.conv(x2)
+         out = x1 + x2
+         return out
+
+
+ # original combiner
+ class DecCombinerCell(nn.Module):
+     def __init__(self, Cin1, Cin2, Cout, cell_type):
+         super(DecCombinerCell, self).__init__()
+         self.cell_type = cell_type
+         self.conv = Conv2D(Cin1 + Cin2, Cout, kernel_size=1, stride=1, padding=0, bias=True)
+
+     def forward(self, x1, x2):
+         out = torch.cat([x1, x2], dim=1)
+         out = self.conv(out)
+         return out
+
+
+ class ConvBNSwish(nn.Module):
+     def __init__(self, Cin, Cout, k=3, stride=1, groups=1, dilation=1):
+         padding = dilation * (k - 1) // 2
+         super(ConvBNSwish, self).__init__()
+
+         self.conv = nn.Sequential(
+             Conv2D(Cin, Cout, k, stride, padding, groups=groups, bias=False, dilation=dilation, weight_norm=False),
+             SyncBatchNormSwish(Cout, eps=BN_EPS, momentum=0.05)  # drop in replacement for BN + Swish
+         )
+
+     def forward(self, x):
+         return self.conv(x)
+
+
+ class SE(nn.Module):
+     def __init__(self, Cin, Cout):
+         super(SE, self).__init__()
+         num_hidden = max(Cout // 16, 4)
+         self.se = nn.Sequential(nn.Linear(Cin, num_hidden), nn.ReLU(inplace=True),
+                                 nn.Linear(num_hidden, Cout), nn.Sigmoid())
+
+     def forward(self, x):
+         se = torch.mean(x, dim=[2, 3])
+         se = se.view(se.size(0), -1)
+         se = self.se(se)
+         se = se.view(se.size(0), -1, 1, 1)
+         return x * se
+
+
+ class InvertedResidual(nn.Module):
+     def __init__(self, Cin, Cout, stride, ex, dil, k, g):
+         super(InvertedResidual, self).__init__()
+         self.stride = stride
+         assert stride in [1, 2, -1]
+
+         hidden_dim = int(round(Cin * ex))
+         self.use_res_connect = self.stride == 1 and Cin == Cout
+         self.upsample = self.stride == -1
+         self.stride = abs(self.stride)
+         groups = hidden_dim if g == 0 else g
+
+         layers0 = [nn.UpsamplingNearest2d(scale_factor=2)] if self.upsample else []
+         layers = [get_batchnorm(Cin, eps=BN_EPS, momentum=0.05),
+                   ConvBNSwish(Cin, hidden_dim, k=1),
+                   ConvBNSwish(hidden_dim, hidden_dim, stride=self.stride, groups=groups, k=k, dilation=dil),
+                   Conv2D(hidden_dim, Cout, 1, 1, 0, bias=False, weight_norm=False),
+                   get_batchnorm(Cout, momentum=0.05)]
+
+         layers0.extend(layers)
+         self.conv = nn.Sequential(*layers0)
+
+     def forward(self, x):
+         return self.conv(x)
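
The cells above are all registered in the OPS dictionary under short names and share a (Cin, Cout, stride) constructor signature, presumably so the surrounding D3VAE encoder/decoder code can select operations by name. A minimal sketch of driving one of these cells in isolation (the shapes, channel counts, and the choice of 'res_bnswish' are illustrative assumptions, not taken from the package):

    import torch
    from dsipts.models.d3vae.neural_operations import OPS, get_skip_connection

    x = torch.randn(8, 16, 32, 32)                     # hypothetical batch: 16 channels on a 32x32 grid
    cell = OPS['res_bnswish'](16, 16, 1)               # BatchNorm followed by a 3x3 weight-normalized Conv2D
    skip = get_skip_connection(16, 1, channel_mult=2)  # stride 1 -> Identity shortcut
    out = cell(x) + skip(x)                            # residual-style combination
    print(out.shape)                                   # torch.Size([8, 16, 32, 32])

Note that SyncBatchNormSwish as defined in this copy only applies batch normalization in its forward pass, so the "Swish" part of the name is nominal here.
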
dsipts/models/d3vae/resnet.py
@@ -0,0 +1,153 @@
+ # -*-Encoding: utf-8 -*-
+ """
+ Authors:
+     Li,Yan (liyan22021121@gmail.com)
+ """
+ from torch import nn
+
+
+ def weights_init(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(0.0, 0.2)
+     elif classname.find("BatchNorm") != -1:
+         m.weight.data.normal_(1.0, 0.2)
+         m.bias.data.fill_(0)
+
+
+ class MyConvo2d(nn.Module):
+     def __init__(self, input_dim, output_dim, kernel_size, stride = 1, bias = True):
+         super(MyConvo2d, self).__init__()
+         self.padding = int((kernel_size - 1)/2)
+         self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=self.padding, bias = bias)
+
+     def forward(self, input):
+         output = self.conv(input)
+         return output
+
+
+ class Square(nn.Module):
+     def __init__(self):
+         super(Square,self).__init__()
+         pass
+
+     def forward(self,in_vect):
+         return in_vect**2
+
+
+ class Swish(nn.Module):
+     def __init__(self):
+         super(Swish,self).__init__()
+         pass
+
+     def forward(self,in_vect):
+         return in_vect*nn.functional.sigmoid(in_vect)
+
+
+ class MeanPoolConv(nn.Module):
+     def __init__(self, input_dim, output_dim, kernel_size):
+         super(MeanPoolConv, self).__init__()
+         self.conv = MyConvo2d(input_dim, output_dim, kernel_size)
+
+     def forward(self, input):
+         output = input
+         output = self.conv(output)
+         return output
+
+
+ class ConvMeanPool(nn.Module):
+     def __init__(self, input_dim, output_dim, kernel_size):
+         super(ConvMeanPool, self).__init__()
+         self.conv = MyConvo2d(input_dim, output_dim, kernel_size)
+
+     def forward(self, input):
+         output = self.conv(input)
+         return output
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, input_dim, output_dim, kernel_size, hw, resample=None, normalize=False, AF=nn.ELU()):
+         super(ResidualBlock, self).__init__()
+
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.kernel_size = kernel_size
+         self.resample = resample
+         self.normalize = normalize
+         self.bn1 = None
+         self.bn2 = None
+         self.relu1 = AF
+         self.relu2 = AF
+         if resample == 'down':
+             self.bn1 = nn.LayerNorm([input_dim, hw, hw])
+             self.bn2 = nn.LayerNorm([input_dim, hw, hw])
+         elif resample == 'none':
+             self.bn1 = nn.LayerNorm([input_dim, hw, hw])
+             self.bn2 = nn.LayerNorm([input_dim, hw, hw])
+
+         if resample == 'down':
+             self.conv_shortcut = MeanPoolConv(input_dim, output_dim, kernel_size = 1)
+             self.conv_1 = MyConvo2d(input_dim, input_dim, kernel_size = kernel_size, bias = False)
+             self.conv_2 = ConvMeanPool(input_dim, output_dim, kernel_size = kernel_size)
+         elif resample == 'none':
+             self.conv_shortcut = MyConvo2d(input_dim, output_dim, kernel_size = 1)
+             self.conv_1 = MyConvo2d(input_dim, input_dim, kernel_size = kernel_size, bias = False)
+             self.conv_2 = MyConvo2d(input_dim, output_dim, kernel_size = kernel_size)
+
+     def forward(self, input):
+         if self.input_dim == self.output_dim and self.resample is None:
+             shortcut = input
+         else:
+             shortcut = self.conv_shortcut(input)
+
+         if self.normalize is False:
+             output = input
+             output = self.relu1(output)
+             output = self.conv_1(output)
+             output = self.relu2(output)
+             output = self.conv_2(output)
+         else:
+             output = input
+             output = self.bn1(output)
+             output = self.relu1(output)
+             output = self.conv_1(output)
+             output = self.bn2(output)
+             output = self.relu2(output)
+             output = self.conv_2(output)
+
+         return shortcut + output
+
+
+ class Res12_Quadratic(nn.Module):
+     def __init__(self, inchan, dim, hw, normalize=False, AF=None):
+         super(Res12_Quadratic, self).__init__()
+
+         self.hw = hw
+         self.dim = dim
+         self.inchan = inchan
+         self.conv1 = MyConvo2d(inchan, dim, 3)
+         self.rb1 = ResidualBlock(dim, 2*dim, 3, int(hw), resample = 'down', normalize=normalize, AF=AF)
+         self.rbc1 = ResidualBlock(2*dim, 2*dim, 3, int(hw/2), resample = 'none', normalize=normalize, AF=AF)
+         self.rb2 = ResidualBlock(2*dim, 4*dim, 3, int(hw/2), resample = 'down', normalize=normalize, AF=AF)
+         self.rbc2 = ResidualBlock(4*dim, 4*dim, 3, int(hw/4), resample = 'none', normalize=normalize, AF=AF)
+         self.rb3 = ResidualBlock(4*dim, 8*dim, 3, int(hw/4), resample = 'down', normalize=normalize, AF=AF)
+         self.rbc3 = ResidualBlock(8*dim, 8*dim, 3, int(hw/8), resample = 'none', normalize=normalize, AF=AF)
+         self.ln1 = nn.Linear(int(hw/8)*int(hw/8)*8*dim, 1)
+         self.ln2 = nn.Linear(int(hw/8)*int(hw/8)*8*dim, 1)
+         self.lq = nn.Linear(int(hw/8)*int(hw/8)*8*dim, 1)
+         self.Square = Square()
+
+     def forward(self, x_in):
+         output = x_in
+         output = self.conv1(output)
+         # print(output.shape)
+         output = self.rb1(output)
+         output = self.rbc1(output)
+         output = self.rb2(output)
+         output = self.rbc2(output)
+         output = self.rb3(output)
+         output = self.rbc3(output)
+         output = output.view(-1, int(self.hw/8)*int(self.hw/8)*8*self.dim)
+         output = self.ln1(output)*self.ln2(output)+self.lq(self.Square(output))
+         output = output.view(-1)
+         return output
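
ResidualBlock above is configured entirely through its constructor, with resample='down' or 'none' selecting which shortcut and normalization layers get built. A minimal sketch of instantiating it directly (batch shape, channel counts, and the ELU activation are illustrative assumptions):

    import torch
    from torch import nn
    from dsipts.models.d3vae.resnet import ResidualBlock

    x = torch.randn(4, 32, 16, 16)   # hypothetical input: 32 channels on a 16x16 grid
    block = ResidualBlock(32, 64, 3, 16, resample='down', normalize=True, AF=nn.ELU())
    y = block(x)
    print(y.shape)                   # torch.Size([4, 64, 16, 16])

Despite the 'down' label, the spatial resolution is unchanged here: MeanPoolConv and ConvMeanPool in this copy apply only their convolution and never pool, which suggests the flattening to int(hw/8)*int(hw/8)*8*dim features in Res12_Quadratic.forward assumes a downsampling step that this file does not actually perform.
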