dsipts 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic. Click here for more details.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
# -*-Encoding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
Description:
|
|
4
|
+
The model architecture of the bidirectional vae.
|
|
5
|
+
Note: Parts of the code are borrowed from 'https://github.com/NVlabs/NVAE'
|
|
6
|
+
Authors:
|
|
7
|
+
Li,Yan (liyan22021121@gmail.com)
|
|
8
|
+
"""
|
|
9
|
+
import math
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
from .neural_operations import OPS, EncCombinerCell, DecCombinerCell, Conv2D, get_skip_connection
|
|
14
|
+
from .utils import get_stride_for_cell_type, get_arch_cells
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Cell(nn.Module):
    """Residual cell: a chain of ``OPS`` primitives plus a skip branch.

    The first primitive maps ``Cin`` -> ``Cout`` channels using the stride
    returned by ``get_stride_for_cell_type(cell_type)``; every later primitive
    maps ``Cout`` -> ``Cout`` with stride 1.  The output is
    ``skip(s) + 0.1 * chain(s)``.

    Args:
        Cin: number of input channels.
        Cout: number of output channels.
        cell_type: cell type name, used to pick the stride.
        arch: sequence of primitive names (keys of ``OPS``).
        use_se: stored on the instance; not used inside this class.
    """

    def __init__(self, Cin, Cout, cell_type, arch, use_se):
        super().__init__()
        self.cell_type = cell_type
        first_stride = get_stride_for_cell_type(self.cell_type)
        self.skip = get_skip_connection(Cin, first_stride, channel_mult=2)
        self.use_se = use_se
        self._num_nodes = len(arch)
        self._ops = nn.ModuleList()
        for position, primitive in enumerate(arch):
            # Only the first op changes resolution / channel count.
            if position == 0:
                layer = OPS[primitive](Cin, Cout, first_stride)
            else:
                layer = OPS[primitive](Cout, Cout, 1)
            self._ops.append(layer)

    def forward(self, s):
        """Return ``skip(s) + 0.1 * ops(s)`` where ``ops`` is applied in order."""
        residual = self.skip(s)
        out = s
        for layer in self._ops:
            out = layer(out)
        return residual + 0.1 * out
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def soft_clamp5(x: torch.Tensor) -> torch.Tensor:
    """Smoothly squash ``x`` into the open interval (-5, 5).

    Computes ``5 * tanh(x / 5)``: approximately the identity near zero and
    saturating at +/-5.  The input tensor is not modified.
    """
    scaled = x / 5.0
    return 5.0 * torch.tanh(scaled)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def sample_normal_jit(mu, sigma):
    """Reparameterized sample from N(mu, sigma**2).

    Args:
        mu: mean tensor.
        sigma: standard deviation, broadcastable against ``mu``.

    Returns:
        (z, eps): ``z = eps * sigma + mu`` and the standard-normal noise
        ``eps`` (same shape/dtype/device as ``mu``) used to produce it.

    The previous version built ``z`` with in-place ops on ``eps``, so the
    returned ``eps`` was the very same tensor as ``z``; here they are distinct.
    """
    eps = torch.randn_like(mu)
    z = eps * sigma + mu
    return z, eps
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class Normal:
    """Diagonal Gaussian whose mean and log-std are soft-clamped to (-5, 5).

    Args:
        mu: mean tensor (passed through ``soft_clamp5``).
        log_sigma: log standard deviation (passed through ``soft_clamp5``).
        temp: temperature; when != 1 the standard deviation is multiplied by it.
    """

    def __init__(self, mu, log_sigma, temp=1.):
        self.mu = soft_clamp5(mu)
        log_sigma = soft_clamp5(log_sigma)
        self.sigma = torch.exp(log_sigma)
        if temp != 1.:
            # Out-of-place on purpose: torch.exp saves its output for the
            # backward pass, so an in-place ``*=`` on it makes backward fail
            # with an "modified by an inplace operation" error.
            self.sigma = self.sigma * temp

    def sample(self):
        """Return ``(z, eps)`` from ``sample_normal_jit(self.mu, self.sigma)``."""
        return sample_normal_jit(self.mu, self.sigma)

    def sample_given_eps(self, eps):
        """Map standard-normal noise ``eps`` to a sample: ``eps * sigma + mu``."""
        return eps * self.sigma + self.mu

    def log_p(self, samples):
        """Element-wise Gaussian log-density of ``samples`` (no reduction)."""
        normalized_samples = (samples - self.mu) / self.sigma
        log_p = - 0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - torch.log(self.sigma)
        return log_p

    def kl(self, normal_dist):
        """Element-wise KL(self || normal_dist) between diagonal Gaussians."""
        term1 = (self.mu - normal_dist.mu) / normal_dist.sigma
        term2 = self.sigma / normal_dist.sigma
        return 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(term2)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class NormalDecoder:
    """Interpret a 4-D parameter map as a diagonal Gaussian over its channels.

    The first ``C // 2`` channels of ``param`` become the mean and the remaining
    channels the log standard deviation; ``self.dist`` is the resulting
    ``Normal`` (which soft-clamps both).
    """

    def __init__(self, param):
        # Tuple unpacking also enforces that ``param`` has exactly 4 dims.
        _, channels, _, _ = param.size()
        half = channels // 2
        self.num_c = half
        self.mu = param[:, :half, :, :]
        self.log_sigma = param[:, half:, :, :]
        # Separate std with a 1e-2 floor; ``self.dist`` computes its own sigma
        # from ``log_sigma`` and does not use this attribute.
        self.sigma = torch.exp(self.log_sigma) + 1e-2
        self.dist = Normal(self.mu, self.log_sigma)

    def log_prob(self, samples):
        """Log-density of ``samples`` under ``self.dist`` (via ``Normal.log_p``)."""
        return self.dist.log_p(samples)

    def sample(self):
        """Draw one reparameterized sample from ``self.dist``; the noise is discarded."""
        drawn, _noise = self.dist.sample()
        return drawn
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def log_density_gaussian(sample, mu, logvar):
    """Batch-normalised total-correlation-style score of a diagonal Gaussian.

    The element-wise Gaussian log-density ``log N(sample; mu, exp(logvar))`` is
    reduced in two ways:

    * ``log_qz``: sum over dims 2 and 3, then ``logsumexp`` over dim 1;
    * ``log_prod_qzi``: ``logsumexp`` over dim 1, then sum over the remaining
      two trailing dims.

    Their difference (one value per batch item for 4-D inputs) is min-max
    normalised across the batch and averaged.

    Parameters
    ----------
    sample: torch.Tensor
        Value at which to compute the density; expected 4-D (batch first).
    mu: torch.Tensor
        Mean, broadcastable against ``sample``.
    logvar: torch.Tensor
        Log variance, broadcastable against ``sample``.

    Returns
    -------
    torch.Tensor
        Scalar in [0, 1]. When every batch item has the same score
        (e.g. a batch of one) the result is 0 instead of NaN.
    """
    normalization = - 0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    log_density = normalization - 0.5 * ((sample - mu) ** 2 * inv_var)
    log_qz = torch.logsumexp(torch.sum(log_density, [2, 3]), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(log_density, dim=1, keepdim=False).sum((1, 2))
    loss_p_z = log_qz - log_prod_qzi

    lowest = torch.min(loss_p_z)
    shifted = loss_p_z - lowest
    value_range = torch.max(loss_p_z) - lowest
    # A zero range used to give 0/0 = NaN. In that case every shifted value is
    # 0, so dividing by 1 yields 0; otherwise the original range is used as-is.
    # (Selecting the denominator, not the quotient, keeps gradients NaN-free.)
    safe_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
    return (shifted / safe_range).mean()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class Encoder(nn.Module):
    """NVAE-style bidirectional VAE used as the generative model of D3VAE.

    ``forward`` runs a bottom-up encoder over the input map, samples one latent
    group per ``combiner_dec`` cell of a top-down decoder, post-processes the
    decoder output into a 2-channel map (``image_conditional``), passes each
    channel through a GRU (``sequence_length`` -> ``prediction_length``
    features) and projects the last ``embedding_dimension + hidden_size``
    columns to ``target_dim``.
    """

    def __init__(self, channel_mult, mult, prediction_length, num_preprocess_blocks, num_preprocess_cells,
                 num_channels_enc, arch_instance, num_latent_per_group, num_channels_dec, groups_per_scale,
                 num_postprocess_blocks, num_postprocess_cells, embedding_dimension, hidden_size, target_dim,
                 sequence_length, num_layers, dropout_rate):
        super(Encoder, self).__init__()

        self.channel_mult = channel_mult
        self.mult = mult
        self.prediction_length = prediction_length
        self.num_preprocess_blocks = num_preprocess_blocks
        self.num_preprocess_cells = num_preprocess_cells
        self.num_channels_enc = num_channels_enc
        self.arch_instance = get_arch_cells(arch_instance)
        # 1 input channel -> num_channels_enc feature maps (assumes Conv2D(C_in, C_out, k, ...)).
        self.stem = Conv2D(1, num_channels_enc, 3, padding=1, bias=True)
        self.num_latent_per_group = num_latent_per_group

        self.num_channels_dec = num_channels_dec
        self.groups_per_scale = groups_per_scale
        self.num_postprocess_blocks = num_postprocess_blocks
        self.num_postprocess_cells = num_postprocess_cells
        self.use_se = False
        self.input_size = embedding_dimension
        self.hidden_size = hidden_size
        self.projection = nn.Linear(embedding_dimension + hidden_size, target_dim)

        # Channel / spatial factors accumulated by the pre-processing blocks.
        c_scaling = self.channel_mult ** (self.num_preprocess_blocks)
        spatial_scaling = 2 ** (self.num_preprocess_blocks)

        # Learned starting feature map of the top-down (decoder) pass.
        prior_ftr0_size = (int(c_scaling * self.num_channels_dec),
                           sequence_length // spatial_scaling,
                           (embedding_dimension + hidden_size + 1) // spatial_scaling)
        self.prior_ftr0 = nn.Parameter(torch.rand(size=prior_ftr0_size), requires_grad=True)
        self.z0_size = [self.num_latent_per_group, sequence_length // spatial_scaling,
                        (embedding_dimension + hidden_size + 1) // spatial_scaling]

        # NOTE: several init_* helpers overwrite self.mult, so this call order matters.
        self.pre_process = self.init_pre_process(self.mult)
        self.enc_tower = self.init_encoder_tower(self.mult)

        self.enc0 = nn.Sequential(nn.ELU(), Conv2D(self.num_channels_enc * self.mult,
                                                    self.num_channels_enc * self.mult, kernel_size=1, bias=True),
                                  nn.ELU())

        self.enc_sampler, self.dec_sampler = self.init_sampler(self.mult)

        self.dec_tower = self.init_decoder_tower(self.mult)

        self.post_process = self.init_post_process(self.mult)
        self.image_conditional = nn.Sequential(nn.ELU(),
                                               Conv2D(int(self.num_channels_dec * self.mult), 2, 3, padding=1,
                                                      bias=True))
        self.rnn = nn.GRU(
            input_size=sequence_length,
            hidden_size=prediction_length,
            num_layers=num_layers,
            dropout=dropout_rate,
            batch_first=True,
        )

    def init_pre_process(self, mult):
        """Build the pre-processing cells and store the resulting multiplier in ``self.mult``.

        The last cell of each block is a ``down_pre`` cell that multiplies the
        channel count by ``channel_mult``; the others are ``normal_pre`` cells.
        """
        pre_process = nn.ModuleList()
        for _ in range(self.num_preprocess_blocks):
            for c in range(self.num_preprocess_cells):
                if c == self.num_preprocess_cells - 1:
                    arch = self.arch_instance['down_pre']
                    num_ci = int(self.num_channels_enc * mult)
                    num_co = int(self.channel_mult * num_ci)
                    cell = Cell(num_ci, num_co, cell_type='down_pre', arch=arch, use_se=self.use_se)
                    mult = self.channel_mult * mult
                else:
                    arch = self.arch_instance['normal_pre']
                    num_c = self.num_channels_enc * mult
                    cell = Cell(num_c, num_c, cell_type='normal_pre', arch=arch, use_se=self.use_se)
                pre_process.append(cell)
        self.mult = mult
        return pre_process

    def init_encoder_tower(self, mult):
        """Build ``groups_per_scale`` ``normal_enc`` cells with an ``EncCombinerCell`` between consecutive ones."""
        enc_tower = nn.ModuleList()
        for g in range(self.groups_per_scale):
            arch = self.arch_instance['normal_enc']
            num_c = int(self.num_channels_enc * mult)
            cell = Cell(num_c, num_c, cell_type='normal_enc', arch=arch, use_se=self.use_se)
            enc_tower.append(cell)

            if not (g == self.groups_per_scale - 1):
                num_ce = int(self.num_channels_enc * mult)
                num_cd = int(self.num_channels_dec * mult)
                cell = EncCombinerCell(num_ce, num_cd, num_ce, cell_type='combiner_enc')
                enc_tower.append(cell)

        self.mult = mult
        return enc_tower

    def init_decoder_tower(self, mult):
        """Build one ``DecCombinerCell`` per group, each preceded by a ``normal_dec`` cell except the first."""
        dec_tower = nn.ModuleList()
        for g in range(self.groups_per_scale):
            num_c = int(self.num_channels_dec * mult)
            if not (g == 0):
                arch = self.arch_instance['normal_dec']
                cell = Cell(num_c, num_c, cell_type='normal_dec', arch=arch, use_se=self.use_se)
                dec_tower.append(cell)
            cell = DecCombinerCell(num_c, self.num_latent_per_group, num_c, cell_type='combiner_dec')
            dec_tower.append(cell)
        self.mult = mult
        return dec_tower

    def init_sampler(self, mult):
        """Build the convs that output ``2 * num_latent_per_group`` channels (mu, log-sigma).

        ``enc_sampler`` has one conv per group; ``dec_sampler`` has one per group
        except the first.  All groups share the same channel multiplier.
        """
        enc_sampler = nn.ModuleList()
        dec_sampler = nn.ModuleList()
        for g in range(self.groups_per_scale):
            num_c = int(self.num_channels_enc * mult)
            cell = Conv2D(num_c, 2 * self.num_latent_per_group, kernel_size=3, padding=1, bias=True)
            enc_sampler.append(cell)
            if g != 0:
                num_c = int(self.num_channels_dec * mult)
                cell = nn.Sequential(
                    nn.ELU(),
                    Conv2D(num_c, 2 * self.num_latent_per_group, kernel_size=1, padding=0, bias=True))
                dec_sampler.append(cell)
        return enc_sampler, dec_sampler

    def init_post_process(self, mult):
        """Build the post-processing cells and store the resulting multiplier in ``self.mult``.

        The first cell of each block is an ``up_post`` cell that divides the
        channel count by ``channel_mult``; the others are ``normal_post`` cells.
        """
        post_process = nn.ModuleList()
        for _ in range(self.num_postprocess_blocks):
            for c in range(self.num_postprocess_cells):
                if c == 0:
                    arch = self.arch_instance['up_post']
                    num_ci = int(self.num_channels_dec * mult)
                    num_co = int(num_ci / self.channel_mult)
                    cell = Cell(num_ci, num_co, cell_type='up_post', arch=arch, use_se=self.use_se)
                    mult = mult / self.channel_mult
                else:
                    arch = self.arch_instance['normal_post']
                    num_c = int(self.num_channels_dec * mult)
                    cell = Cell(num_c, num_c, cell_type='normal_post', arch=arch, use_se=self.use_se)
                post_process.append(cell)
        self.mult = mult
        return post_process

    def forward(self, x):
        """Run the bidirectional VAE.

        Args:
            x: input map for ``self.stem``; presumably
                (batch, 1, sequence_length, width) -- TODO confirm against callers.

        Returns:
            logits: ``self.projection`` applied to the stacked per-channel GRU outputs.
            total_c: mean over all latent groups of ``log_density_gaussian``.
            all_z: list with the sampled latent tensor of every group.
        """
        s = self.stem(2 * x - 1.0)
        for cell in self.pre_process:
            s = cell(s)

        # Bottom-up pass; remember each combiner cell and the features it will combine.
        combiner_cells_enc = []
        combiner_cells_s = []
        all_z = []
        for cell in self.enc_tower:
            if cell.cell_type == 'combiner_enc':
                combiner_cells_enc.append(cell)
                combiner_cells_s.append(s)
            else:
                s = cell(s)
        combiner_cells_enc.reverse()
        combiner_cells_s.reverse()

        # First latent group z_0 from the top of the encoder.
        ftr = self.enc0(s)
        param0 = self.enc_sampler[0](ftr)
        mu_q, log_sig_q = torch.chunk(param0, 2, dim=1)
        dist = Normal(mu_q, log_sig_q)
        z, _ = dist.sample()
        all_z.append(z)
        # z_0's term is part of the average below (it used to be computed and dropped
        # while the sum was still divided by the full number of groups).
        total_c = log_density_gaussian(z, mu_q, log_sig_q)

        # Top-down pass starting from the learned prior feature map.
        batch_size = z.size(0)
        s = self.prior_ftr0.unsqueeze(0).expand(batch_size, -1, -1, -1)
        idx_dec = 0
        for cell in self.dec_tower:
            if cell.cell_type == 'combiner_dec':
                if idx_dec > 0:
                    ftr = combiner_cells_enc[idx_dec - 1](combiner_cells_s[idx_dec - 1], s)
                    param = self.enc_sampler[idx_dec](ftr)
                    mu_q, log_sig_q = torch.chunk(param, 2, dim=1)
                    dist = Normal(mu_q, log_sig_q)
                    z, _ = dist.sample()  # z_n
                    all_z.append(z)
                    total_c = total_c + log_density_gaussian(z, mu_q, log_sig_q)
                s = cell(s, z)
                idx_dec += 1
            else:
                s = cell(s)

        for cell in self.post_process:
            s = cell(s)

        logits = self.image_conditional(s)
        # One GRU pass per output channel of image_conditional. Indexing already
        # removes the channel dim; no squeeze (it dropped the batch dim for batch 1).
        per_channel = []
        for i in range(logits.size(1)):
            tmp, _ = self.rnn(logits[:, i, :, :].permute(0, 2, 1))
            per_channel.append(tmp.permute(0, 2, 1))
        logits = torch.stack(per_channel, 1)
        logits = self.projection(logits[..., -(self.input_size + self.hidden_size):])
        total_c = total_c / idx_dec
        return logits, total_c, all_z

    def decoder_output(self, logits):
        """Wrap ``logits`` in a ``NormalDecoder`` (channel halves -> mean / log-std)."""
        return NormalDecoder(logits)
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# -*-Encoding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
Authors:
|
|
4
|
+
Li,Yan (liyan22021121@gmail.com)
|
|
5
|
+
"""
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import numpy as np
|
|
9
|
+
from .resnet import Res12_Quadratic
|
|
10
|
+
from .diffusion_process import GaussianDiffusion, get_beta_schedule
|
|
11
|
+
from .encoder import Encoder
|
|
12
|
+
from .embedding import DataEmbedding
|
|
13
|
+
from ...data_structure.utils import beauty_string
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class diffusion_generate(nn.Module):
    """GRU context encoder feeding the coupled diffusion process / BVAE generator.

    ``self.generative`` is the bidirectional VAE (``Encoder``) and
    ``self.diffusion`` is a ``GaussianDiffusion`` wrapping it.
    """

    def __init__(self, target_dim, embedding_dimension, prediction_length, sequence_length, scale, hidden_size,
                 num_layers, dropout_rate, diff_steps, loss_type, beta_end, beta_schedule, channel_mult, mult,
                 num_preprocess_blocks, num_preprocess_cells, num_channels_enc, arch_instance, num_latent_per_group,
                 num_channels_dec, groups_per_scale, num_postprocess_blocks, num_postprocess_cells):
        super().__init__()
        self.target_dim = target_dim
        self.input_size = embedding_dimension
        self.prediction_length = prediction_length
        self.seq_length = sequence_length
        self.scale = scale
        self.rnn = nn.GRU(input_size=self.input_size, hidden_size=hidden_size, num_layers=num_layers,
                          dropout=dropout_rate, batch_first=True)

        self.generative = Encoder(
            channel_mult, mult, prediction_length,
            num_preprocess_blocks, num_preprocess_cells, num_channels_enc, arch_instance,
            num_latent_per_group, num_channels_dec, groups_per_scale,
            num_postprocess_blocks, num_postprocess_cells,
            embedding_dimension, hidden_size, target_dim, sequence_length, num_layers, dropout_rate,
        )
        self.diffusion = GaussianDiffusion(self.generative, input_size=target_dim, diff_steps=diff_steps,
                                           loss_type=loss_type, beta_end=beta_end, beta_schedule=beta_schedule,
                                           scale=scale)
        # Not used by forward() in this class.
        self.projection = nn.Linear(embedding_dimension + hidden_size, embedding_dimension)

    def forward(self, past_time_feat, future_time_feat, t):
        """Return ``(output, y_noisy, total_c, all_z)`` from ``self.diffusion.log_prob``.

        The GRU output over ``past_time_feat`` is concatenated (last dim) with
        ``past_time_feat`` itself and used as the conditioning input.
        """
        context, _hidden = self.rnn(past_time_feat)
        conditioned = torch.cat([context, past_time_feat], dim=-1)
        dist_out, noisy_target, tc, latents = self.diffusion.log_prob(conditioned, future_time_feat, t)
        return dist_out, noisy_target, tc, latents
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class denoise_net(nn.Module):
    """D3VAE training network.

    The model consists of three main parts: the coupled diffusion process and
    the generative model (both inside ``diffusion_generate``), and a ResNet
    (``score_net``) used to calculate the score for denoising score matching.
    """

    def __init__(self, target_dim, embedding_dimension, prediction_length, sequence_length, scale, hidden_size,
                 num_layers, dropout_rate, diff_steps, loss_type, beta_end, beta_schedule, channel_mult, mult,
                 num_preprocess_blocks, num_preprocess_cells, num_channels_enc, arch_instance, num_latent_per_group,
                 num_channels_dec, groups_per_scale, num_postprocess_blocks, num_postprocess_cells, beta_start,
                 input_dim, freq, embs):
        super().__init__()
        # ResNet used to calculate the scores.
        self.score_net = Res12_Quadratic(1, 64, 32, normalize=False, AF=nn.ELU())

        # Diffusion schedule. Plain tensor attributes (not buffers); forward()
        # moves ``sigmas`` to the target device on each call.
        betas = get_beta_schedule(beta_schedule, beta_start, beta_end, diff_steps)
        (self.alphas_cumprod,
         self.sqrt_alphas_cumprod,
         self.sqrt_one_minus_alphas_cumprod,
         self.sigmas) = self._make_schedule(betas)

        # The generative BVAE model coupled with the diffusion process.
        self.diffusion_gen = diffusion_generate(target_dim, embedding_dimension, prediction_length, sequence_length,
                                                scale, hidden_size, num_layers, dropout_rate, diff_steps, loss_type,
                                                beta_end, beta_schedule, channel_mult, mult,
                                                num_preprocess_blocks, num_preprocess_cells, num_channels_enc,
                                                arch_instance, num_latent_per_group, num_channels_dec,
                                                groups_per_scale, num_postprocess_blocks, num_postprocess_cells)

        # Data embedding module.
        self.embedding = DataEmbedding(input_dim, embedding_dimension, embs, dropout_rate)

    @staticmethod
    def _make_schedule(betas):
        """Derive the cumulative schedule tensors from a numpy schedule array.

        Returns ``(alphas_cumprod, sqrt(alphas_cumprod), sqrt(1 - alphas_cumprod),
        sigmas = 1 - alphas_cumprod)`` where ``alphas = 1 - betas / 2``.
        ``sigmas`` is built directly from the tensor (no ``torch.tensor(tensor)``
        copy-construct, which PyTorch warns about).
        """
        alphas = 1.0 - betas * 0.5
        cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod = torch.tensor(cumprod)
        return (
            alphas_cumprod,
            torch.tensor(np.sqrt(cumprod)),
            torch.tensor(np.sqrt(1 - cumprod)),
            1.0 - alphas_cumprod,
        )

    def extract(self, a, t, x_shape):
        """Gather ``a[..., t]`` and reshape to ``(batch, 1, ..., 1)`` with ``len(x_shape)`` dims."""
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))

    def forward(self, past_time_feat, mark, future_time_feat, t):
        """
        Params:
            past_time_feat: Tensor
                the input time series.
            mark: Tensor
                the time feature mark.
            future_time_feat: Tensor
                the target time series.
            t: Tensor
                the diffusion step.
        -------------
        return:
            output: Tensor
                The Gaussian distribution of the generative results.
            y_noisy: Tensor
                The diffused target.
            total_c: Float
                Total correlation of all the latent variables in the BVAE, used for disentangling.
            all_z: List
                All the latent variables of the BVAE.
            loss: Float
                The loss of score matching.
        """
        # Embed the original time series.
        embedded = self.embedding(past_time_feat, mark)
        # Distribution of the generative results, the diffused target and the
        # total correlation / latents of the generative model.
        output, y_noisy, total_c, all_z = self.diffusion_gen(embedded, future_time_feat, t)

        # Score matching.
        sigmas_t = self.extract(self.sigmas.to(y_noisy.device), t, y_noisy.shape)
        y = future_time_feat.unsqueeze(1).float()
        y_noisy1 = output.sample().float().requires_grad_()
        E = self.score_net(y_noisy1).sum()

        # Loss of multiscale score matching.
        grad_x = torch.autograd.grad(E, y_noisy1, create_graph=True)[0]
        loss = torch.mean(torch.sum(((y - y_noisy1.detach()) + grad_x * 0.001) ** 2 * sigmas_t, [1, 2, 3])).float()
        return output, y_noisy, total_c, all_z, loss
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class pred_net(denoise_net):
    """Prediction variant of ``denoise_net``: generate, then denoise with one score step."""

    def forward(self, x, mark):
        """
        Generate the prediction with the trained model.

        Return:
            y: The noisy generative results (``mu`` of the decoder output).
            out: Denoised results, ``y - 0.001 * d(score_net(y).sum())/dy``.
                Equal to ``y`` when the score gradient cannot be computed.
            tc: Total correlations, indicator of the extent of disentangling.
            all_z: Latent variables of the BVAE.
        """
        embedded = self.embedding(x, mark)
        x_t, _ = self.diffusion_gen.rnn(embedded)
        gen_input = torch.cat([x_t, embedded], dim=-1).unsqueeze(1)
        logits, tc, all_z = self.diffusion_gen.generative(gen_input)
        output = self.diffusion_gen.generative.decoder_output(logits)
        y = output.mu.float().requires_grad_()

        grad_x = None
        try:
            E = self.score_net(y).sum()
            grad_x = torch.autograd.grad(E, y, create_graph=True, allow_unused=True)[0]
        except Exception as e:
            # Best effort: e.g. autograd.grad raises when E carries no graph
            # (such as under torch.no_grad()); report and skip denoising.
            beauty_string(e, '')
        if grad_x is None:
            # Reached after the except branch, or when allow_unused=True returns
            # None; previously the latter crashed at ``grad_x * 0.001``.
            grad_x = 0

        out = y - grad_x * 0.001
        return y, out, tc, all_z
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class Discriminator(nn.Module):
    """MLP discriminator from Kim & Mnih, "Disentangling by factorising" (arXiv:1802.05983, 2018).

    Parameters
    ----------
    neg_slope: float
        Negative slope of the LeakyReLU activations.
    latent_dim : int
        Dimensionality of the latent variables (input features).
    hidden_units: int
        Width of each hidden layer (default 1000).
    out_units: int
        Number of output logits (default 2; one unit with a sigmoid is the
        theoretical choice but reportedly works worse than two with softmax).

    Architecture: six linear layers, LeakyReLU after the first five.
    ``forward`` returns raw logits; ``self.softmax`` is not applied there.
    """

    def __init__(self, neg_slope=0.2, latent_dim=10, hidden_units=1000, out_units=2):
        super(Discriminator, self).__init__()

        # Activation (in-place LeakyReLU).
        self.neg_slope = neg_slope
        self.leaky_relu = nn.LeakyReLU(self.neg_slope, True)

        # Layer sizes.
        self.z_dim = latent_dim
        self.hidden_units = hidden_units

        # lin1 .. lin6, created in order so parameter initialisation order is stable.
        widths = [self.z_dim] + [hidden_units] * 5 + [out_units]
        for layer_no in range(1, 7):
            setattr(self, f"lin{layer_no}", nn.Linear(widths[layer_no - 1], widths[layer_no]))
        self.softmax = nn.Softmax()

    def forward(self, z):
        """Return the ``out_units`` logits for latent codes ``z``."""
        for hidden in (self.lin1, self.lin2, self.lin3, self.lin4, self.lin5):
            z = self.leaky_relu(hidden(z))
        return self.lin6(z)
|