dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/d3vae/encoder.py
@@ -0,0 +1,326 @@
+ # -*- coding: utf-8 -*-
+ """
+ Description:
+     The model architecture of the bidirectional VAE.
+     Note: part of the code is borrowed from https://github.com/NVlabs/NVAE
+ Authors:
+     Li, Yan (liyan22021121@gmail.com)
+ """
+ import math
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from .neural_operations import OPS, EncCombinerCell, DecCombinerCell, Conv2D, get_skip_connection
+ from .utils import get_stride_for_cell_type, get_arch_cells
+
+
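+ # The cells and towers below follow NVAE's layout: a preprocessing stack, a
+ # bottom-up encoder tower, a top-down decoder tower with per-group latent
+ # samplers, and a postprocessing stack, assembled from the OPS primitives above.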
+ class Cell(nn.Module):
+     def __init__(self, Cin, Cout, cell_type, arch, use_se):
+         super(Cell, self).__init__()
+         self.cell_type = cell_type
+         stride = get_stride_for_cell_type(self.cell_type)
+         self.skip = get_skip_connection(Cin, stride, channel_mult=2)
+         self.use_se = use_se
+         self._num_nodes = len(arch)
+         self._ops = nn.ModuleList()
+         for i in range(self._num_nodes):
+             # only the first op may change stride and channel count
+             stride = get_stride_for_cell_type(self.cell_type) if i == 0 else 1
+             primitive = arch[i]
+             if i == 0:
+                 op = OPS[primitive](Cin, Cout, stride)
+             else:
+                 op = OPS[primitive](Cout, Cout, stride)
+             self._ops.append(op)
+
+     def forward(self, s):
+         # skip branch plus a damped residual branch
+         skip = self.skip(s)
+         for i in range(self._num_nodes):
+             s = self._ops[i](s)
+         return skip + 0.1 * s
+
+
+ def soft_clamp5(x: torch.Tensor):
+     # smooth clamp to (-5, 5): 5 * tanh(x / 5)
+     return x.div(5.).tanh_().mul(5.)
+
+
+ def sample_normal_jit(mu, sigma):
+     # reparameterized sample: z = mu + sigma * eps, with eps ~ N(0, I)
+     eps = mu.mul(0).normal_()
+     z = eps.mul_(sigma).add_(mu)
+     return z, eps
+
+
+ class Normal:
+     def __init__(self, mu, log_sigma, temp=1.):
+         self.mu = soft_clamp5(mu)
+         log_sigma = soft_clamp5(log_sigma)
+         self.sigma = torch.exp(log_sigma)
+         if temp != 1.:
+             self.sigma *= temp
+
+     def sample(self):
+         return sample_normal_jit(self.mu, self.sigma)
+
+     def sample_given_eps(self, eps):
+         return eps * self.sigma + self.mu
+
+     def log_p(self, samples):
+         # element-wise Gaussian log density
+         normalized_samples = (samples - self.mu) / self.sigma
+         log_p = -0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - torch.log(self.sigma)
+         return log_p
+
+     def kl(self, normal_dist):
+         term1 = (self.mu - normal_dist.mu) / normal_dist.sigma
+         term2 = self.sigma / normal_dist.sigma
+         return 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(term2)
+
+
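+ # For reference, Normal.kl computes the closed-form KL between diagonal Gaussians:
+ #   KL(N(mu1, s1) || N(mu2, s2)) = log(s2 / s1) + (s1**2 + (mu1 - mu2)**2) / (2 * s2**2) - 1/2
+ # with term1 = (mu1 - mu2) / s2 and term2 = s1 / s2.
+
+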
+ class NormalDecoder:
+     def __init__(self, param):
+         # split the channel dimension into mean and log-sigma halves
+         B, C, H, W = param.size()
+         self.num_c = C // 2
+         self.mu = param[:, :self.num_c, :, :]         # B, C/2, H, W
+         self.log_sigma = param[:, self.num_c:, :, :]  # B, C/2, H, W
+         self.sigma = torch.exp(self.log_sigma) + 1e-2
+         self.dist = Normal(self.mu, self.log_sigma)
+
+     def log_prob(self, samples):
+         return self.dist.log_p(samples)
+
+     def sample(self):
+         x, _ = self.dist.sample()
+         return x
+
+
+ def log_density_gaussian(sample, mu, logvar):
+     """Compute element-wise Gaussian log densities and reduce them to a
+     min-max normalized total-correlation-style scalar.
+     Parameters
+     ----------
+     sample: torch.Tensor
+         Value at which to compute the density.
+     mu: torch.Tensor
+         Mean.
+     logvar: torch.Tensor
+         Log variance.
+     """
+     normalization = -0.5 * (math.log(2 * math.pi) + logvar)
+     inv_var = torch.exp(-logvar)
+     log_density = normalization - 0.5 * ((sample - mu) ** 2 * inv_var)
+     # total-correlation style aggregation: log q(z) - sum_i log q(z_i)
+     log_qz = torch.logsumexp(torch.sum(log_density, [2, 3]), dim=1, keepdim=False)
+     log_prod_qzi = torch.logsumexp(log_density, dim=1, keepdim=False).sum((1, 2))
+     loss_p_z = log_qz - log_prod_qzi
+     # min-max normalize across the batch before averaging
+     loss_p_z = ((loss_p_z - torch.min(loss_p_z)) / (torch.max(loss_p_z) - torch.min(loss_p_z))).mean()
+     return loss_p_z
+
+
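+ # Note: Encoder.forward below passes log_sig_q (a log standard deviation, not a log
+ # variance) as `logvar`, and the min-max normalization makes the returned value a
+ # relative, batch-dependent score rather than an absolute log density.
+
+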
+ class Encoder(nn.Module):
+     def __init__(self, channel_mult, mult, prediction_length, num_preprocess_blocks,
+                  num_preprocess_cells, num_channels_enc, arch_instance, num_latent_per_group,
+                  num_channels_dec, groups_per_scale, num_postprocess_blocks, num_postprocess_cells,
+                  embedding_dimension, hidden_size, target_dim, sequence_length, num_layers,
+                  dropout_rate):
+         super(Encoder, self).__init__()
+
+         self.channel_mult = channel_mult
+         self.mult = mult
+         self.prediction_length = prediction_length
+         self.num_preprocess_blocks = num_preprocess_blocks
+         self.num_preprocess_cells = num_preprocess_cells
+         self.num_channels_enc = num_channels_enc
+         self.arch_instance = get_arch_cells(arch_instance)
+         self.stem = Conv2D(1, num_channels_enc, 3, padding=1, bias=True)
+         self.num_latent_per_group = num_latent_per_group
+
+         self.num_channels_dec = num_channels_dec
+         self.groups_per_scale = groups_per_scale
+         self.num_postprocess_blocks = num_postprocess_blocks
+         self.num_postprocess_cells = num_postprocess_cells
+         self.use_se = False
+         self.input_size = embedding_dimension
+         self.hidden_size = hidden_size
+         self.projection = nn.Linear(embedding_dimension + hidden_size, target_dim)
+
+         c_scaling = self.channel_mult ** self.num_preprocess_blocks
+         spatial_scaling = 2 ** self.num_preprocess_blocks
+
+         prior_ftr0_size = (int(c_scaling * self.num_channels_dec),
+                            sequence_length // spatial_scaling,
+                            (embedding_dimension + hidden_size + 1) // spatial_scaling)
+         self.prior_ftr0 = nn.Parameter(torch.rand(size=prior_ftr0_size), requires_grad=True)
+         self.z0_size = [self.num_latent_per_group, sequence_length // spatial_scaling,
+                         (embedding_dimension + hidden_size + 1) // spatial_scaling]
+
+         self.pre_process = self.init_pre_process(self.mult)
+         self.enc_tower = self.init_encoder_tower(self.mult)
+
+         self.enc0 = nn.Sequential(nn.ELU(), Conv2D(self.num_channels_enc * self.mult,
+                                   self.num_channels_enc * self.mult, kernel_size=1, bias=True), nn.ELU())
+
+         self.enc_sampler, self.dec_sampler = self.init_sampler(self.mult)
+
+         self.dec_tower = self.init_decoder_tower(self.mult)
+
+         self.post_process = self.init_post_process(self.mult)
+         self.image_conditional = nn.Sequential(nn.ELU(),
+                                                Conv2D(int(self.num_channels_dec * self.mult), 2, 3, padding=1, bias=True))
+         self.rnn = nn.GRU(
+             input_size=sequence_length,
+             hidden_size=prediction_length,
+             num_layers=num_layers,
+             dropout=dropout_rate,
+             batch_first=True,
+         )
+
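+     # Shape note (inferred from the construction above): the input series is treated as
+     # a 1-channel image of size (sequence_length, embedding_dimension + hidden_size + 1),
+     # and prior_ftr0 / z0_size live on that grid downscaled by 2 ** num_preprocess_blocks.
+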
+     def init_pre_process(self, mult):
+         pre_process = nn.ModuleList()
+         for b in range(self.num_preprocess_blocks):
+             for c in range(self.num_preprocess_cells):
+                 if c == self.num_preprocess_cells - 1:
+                     # the last cell of each block downsamples and widens the channels
+                     arch = self.arch_instance['down_pre']
+                     num_ci = int(self.num_channels_enc * mult)
+                     num_co = int(self.channel_mult * num_ci)
+                     cell = Cell(num_ci, num_co, cell_type='down_pre', arch=arch, use_se=self.use_se)
+                     mult = self.channel_mult * mult
+                 else:
+                     arch = self.arch_instance['normal_pre']
+                     num_c = self.num_channels_enc * mult
+                     cell = Cell(num_c, num_c, cell_type='normal_pre', arch=arch, use_se=self.use_se)
+                 pre_process.append(cell)
+         self.mult = mult
+         return pre_process
+
+     def init_encoder_tower(self, mult):
+         enc_tower = nn.ModuleList()
+         for g in range(self.groups_per_scale):
+             arch = self.arch_instance['normal_enc']
+             num_c = int(self.num_channels_enc * mult)
+             cell = Cell(num_c, num_c, cell_type='normal_enc', arch=arch, use_se=self.use_se)
+             enc_tower.append(cell)
+
+             # every group except the last is followed by a combiner cell
+             if not (g == self.groups_per_scale - 1):
+                 num_ce = int(self.num_channels_enc * mult)
+                 num_cd = int(self.num_channels_dec * mult)
+                 cell = EncCombinerCell(num_ce, num_cd, num_ce, cell_type='combiner_enc')
+                 enc_tower.append(cell)
+
+         self.mult = mult
+         return enc_tower
+
+     def init_decoder_tower(self, mult):
+         dec_tower = nn.ModuleList()
+         for g in range(self.groups_per_scale):
+             num_c = int(self.num_channels_dec * mult)
+             if not (g == 0):
+                 arch = self.arch_instance['normal_dec']
+                 cell = Cell(num_c, num_c, cell_type='normal_dec', arch=arch, use_se=self.use_se)
+                 dec_tower.append(cell)
+             cell = DecCombinerCell(num_c, self.num_latent_per_group, num_c, cell_type='combiner_dec')
+             dec_tower.append(cell)
+         self.mult = mult
+         return dec_tower
+
+     def init_sampler(self, mult):
+         enc_sampler = nn.ModuleList()
+         dec_sampler = nn.ModuleList()
+         for g in range(self.groups_per_scale):
+             num_c = int(self.num_channels_enc * mult)
+             cell = Conv2D(num_c, 2 * self.num_latent_per_group, kernel_size=3, padding=1, bias=True)
+             enc_sampler.append(cell)
+             if g != 0:
+                 num_c = int(self.num_channels_dec * mult)
+                 cell = nn.Sequential(
+                     nn.ELU(),
+                     Conv2D(num_c, 2 * self.num_latent_per_group, kernel_size=1, padding=0, bias=True))
+                 dec_sampler.append(cell)
+         # kept from NVAE's multi-scale code; with a single scale it has no effect here
+         mult = mult / self.channel_mult
+         return enc_sampler, dec_sampler
+
+     def init_post_process(self, mult):
+         post_process = nn.ModuleList()
+         for b in range(self.num_postprocess_blocks):
+             for c in range(self.num_postprocess_cells):
+                 if c == 0:
+                     # the first cell of each block upsamples and narrows the channels
+                     arch = self.arch_instance['up_post']
+                     num_ci = int(self.num_channels_dec * mult)
+                     num_co = int(num_ci / self.channel_mult)
+                     cell = Cell(num_ci, num_co, cell_type='up_post', arch=arch, use_se=self.use_se)
+                     mult = mult / self.channel_mult
+                 else:
+                     arch = self.arch_instance['normal_post']
+                     num_c = int(self.num_channels_dec * mult)
+                     cell = Cell(num_c, num_c, cell_type='normal_post', arch=arch, use_se=self.use_se)
+                 post_process.append(cell)
+         self.mult = mult
+         return post_process
+
+     def forward(self, x):
+         s = self.stem(2 * x - 1.0)
+         for cell in self.pre_process:
+             s = cell(s)
+
+         # bottom-up encoder pass, caching the activations needed by the combiners
+         combiner_cells_enc = []
+         combiner_cells_s = []
+         all_z = []
+         for cell in self.enc_tower:
+             if cell.cell_type == 'combiner_enc':
+                 combiner_cells_enc.append(cell)
+                 combiner_cells_s.append(s)
+             else:
+                 s = cell(s)
+
+         # the combiners are consumed in reverse order during the top-down pass
+         combiner_cells_enc.reverse()
+         combiner_cells_s.reverse()
+
+         idx_dec = 0
+         ftr = self.enc0(s)
+         param0 = self.enc_sampler[idx_dec](ftr)
+         mu_q, log_sig_q = torch.chunk(param0, 2, dim=1)
+         dist = Normal(mu_q, log_sig_q)
+         z, _ = dist.sample()  # z_0
+         all_z.append(z)
+         loss_qz = log_density_gaussian(z, mu_q, log_sig_q)  # computed for z_0 but not accumulated below
+
+         # start the top-down pass from the learned prior feature
+         s = self.prior_ftr0.unsqueeze(0)
+         batch_size = z.size(0)
+         s = s.expand(batch_size, -1, -1, -1)
+         total_c = 0
+
+         for cell in self.dec_tower:
+             if cell.cell_type == 'combiner_dec':
+                 if idx_dec > 0:
+                     ftr = combiner_cells_enc[idx_dec - 1](combiner_cells_s[idx_dec - 1], s)
+                     param = self.enc_sampler[idx_dec](ftr)
+                     mu_q, log_sig_q = torch.chunk(param, 2, dim=1)
+                     dist = Normal(mu_q, log_sig_q)
+                     z, _ = dist.sample()  # z_n
+                     all_z.append(z)
+                     loss_qz = log_density_gaussian(z, mu_q, log_sig_q)
+                     total_c += loss_qz
+                 s = cell(s, z)
+                 idx_dec += 1
+             else:
+                 s = cell(s)
+
+         for cell in self.post_process:
+             s = cell(s)
+
+         logits = self.image_conditional(s)
+
+         # map each generated channel back onto the forecasting horizon with the GRU
+         tmp_tot = []
+         for i in range(idx_dec):
+             tmp, _ = self.rnn(logits[:, i, :, :].squeeze().permute(0, 2, 1))
+             tmp_tot.append(tmp.permute(0, 2, 1))
+         logits = torch.stack(tmp_tot, 1)
+         logits = self.projection(logits[..., -(self.input_size + self.hidden_size):])
+         total_c = total_c / idx_dec
+         return logits, total_c, all_z
+
+     def decoder_output(self, logits):
+         return NormalDecoder(logits)
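+
+ # decoder_output() wraps the projected logits in a NormalDecoder, which splits its
+ # channel dimension into mean and log-sigma halves of the output distribution.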
dsipts/models/d3vae/model.py
@@ -0,0 +1,211 @@
+ # -*- coding: utf-8 -*-
+ """
+ Authors:
+     Li, Yan (liyan22021121@gmail.com)
+ """
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from .resnet import Res12_Quadratic
+ from .diffusion_process import GaussianDiffusion, get_beta_schedule
+ from .encoder import Encoder
+ from .embedding import DataEmbedding
+ from ...data_structure.utils import beauty_string
+
+
+ class diffusion_generate(nn.Module):
+     def __init__(self, target_dim, embedding_dimension, prediction_length, sequence_length,
+                  scale, hidden_size, num_layers, dropout_rate, diff_steps, loss_type,
+                  beta_end, beta_schedule, channel_mult, mult, num_preprocess_blocks,
+                  num_preprocess_cells, num_channels_enc, arch_instance, num_latent_per_group,
+                  num_channels_dec, groups_per_scale, num_postprocess_blocks, num_postprocess_cells):
+         super().__init__()
+         self.target_dim = target_dim
+         self.input_size = embedding_dimension
+         self.prediction_length = prediction_length
+         self.seq_length = sequence_length
+         self.scale = scale
+         self.rnn = nn.GRU(
+             input_size=self.input_size,
+             hidden_size=hidden_size,
+             num_layers=num_layers,
+             dropout=dropout_rate,
+             batch_first=True,
+         )
+
+         self.generative = Encoder(channel_mult, mult, prediction_length,
+                                   num_preprocess_blocks, num_preprocess_cells, num_channels_enc,
+                                   arch_instance, num_latent_per_group, num_channels_dec,
+                                   groups_per_scale, num_postprocess_blocks, num_postprocess_cells,
+                                   embedding_dimension, hidden_size, target_dim, sequence_length,
+                                   num_layers, dropout_rate)
+         self.diffusion = GaussianDiffusion(
+             self.generative,
+             input_size=target_dim,
+             diff_steps=diff_steps,
+             loss_type=loss_type,
+             beta_end=beta_end,
+             beta_schedule=beta_schedule,
+             scale=scale,
+         )
+         self.projection = nn.Linear(embedding_dimension + hidden_size, embedding_dimension)
+
+     def forward(self, past_time_feat, future_time_feat, t):
+         """
+         Output the generative results and related variables.
+         """
+         time_feat, _ = self.rnn(past_time_feat)
+         input = torch.cat([time_feat, past_time_feat], dim=-1)
+         output, y_noisy, total_c, all_z = self.diffusion.log_prob(input, future_time_feat, t)
+         return output, y_noisy, total_c, all_z
+
+
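+ # diffusion_generate concatenates the GRU-encoded context with the raw embedding and
+ # delegates to GaussianDiffusion.log_prob, which diffuses the target and runs the
+ # bidirectional VAE defined in encoder.py.
+
+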
+ class denoise_net(nn.Module):
+     def __init__(self, target_dim, embedding_dimension, prediction_length, sequence_length,
+                  scale, hidden_size, num_layers, dropout_rate, diff_steps, loss_type,
+                  beta_end, beta_schedule, channel_mult, mult, num_preprocess_blocks,
+                  num_preprocess_cells, num_channels_enc, arch_instance, num_latent_per_group,
+                  num_channels_dec, groups_per_scale, num_postprocess_blocks, num_postprocess_cells,
+                  beta_start, input_dim, freq, embs):
+         """
+         The whole model architecture consists of three main parts: the coupled diffusion
+         process and the generative model are included in the diffusion_generate module,
+         and a ResNet is used to calculate the score.
+         """
+         super().__init__()
+         # ResNet used to calculate the scores.
+         self.score_net = Res12_Quadratic(1, 64, 32, normalize=False, AF=nn.ELU())
+
+         # Generate the diffusion schedule.
+         sigmas = get_beta_schedule(beta_schedule, beta_start, beta_end, diff_steps)
+         alphas = 1.0 - sigmas * 0.5
+         self.alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0))
+         self.sqrt_alphas_cumprod = torch.tensor(np.sqrt(np.cumprod(alphas, axis=0)))
+         self.sqrt_one_minus_alphas_cumprod = torch.tensor(np.sqrt(1 - np.cumprod(alphas, axis=0)))
+         self.sigmas = torch.tensor(1. - self.alphas_cumprod)
+
+         # The generative BVAE model.
+         self.diffusion_gen = diffusion_generate(target_dim, embedding_dimension, prediction_length,
+                                                 sequence_length, scale, hidden_size, num_layers,
+                                                 dropout_rate, diff_steps, loss_type, beta_end,
+                                                 beta_schedule, channel_mult, mult,
+                                                 num_preprocess_blocks, num_preprocess_cells,
+                                                 num_channels_enc, arch_instance, num_latent_per_group,
+                                                 num_channels_dec, groups_per_scale,
+                                                 num_postprocess_blocks, num_postprocess_cells)
+
+         # Data embedding module.
+         self.embedding = DataEmbedding(input_dim, embedding_dimension, embs, dropout_rate)
+
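+     # Schedule bookkeeping: with beta_t from get_beta_schedule, alpha_t = 1 - beta_t / 2,
+     # alphas_cumprod = prod_t(alpha_t), and self.sigmas = 1 - alphas_cumprod is the
+     # per-step weight later gathered by extract() inside the score-matching loss.
+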
+     def extract(self, a, t, x_shape):
+         """Extract the t-th element of a and reshape it to broadcast over x_shape."""
+         b, *_ = t.shape
+         out = a.gather(-1, t)
+         return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+     def forward(self, past_time_feat, mark, future_time_feat, t):
+         """
+         Params:
+             past_time_feat: Tensor
+                 the input time series.
+             mark: Tensor
+                 the time feature mark.
+             future_time_feat: Tensor
+                 the target time series.
+             t: Tensor
+                 the diffusion step.
+         Returns:
+             output: Tensor
+                 The Gaussian distribution of the generative results.
+             y_noisy: Tensor
+                 The diffused target.
+             total_c: Float
+                 Total correlation of all the latent variables in the BVAE, used for disentangling.
+             all_z: List
+                 All the latent variables of the BVAE.
+             loss: Float
+                 The loss of score matching.
+         """
+         # Embed the original time series.
+         input = self.embedding(past_time_feat, mark)
+         # Distribution of the generative results, the sampled generative results and
+         # the total correlation of the generative model.
+         output, y_noisy, total_c, all_z = self.diffusion_gen(input, future_time_feat, t)
+
+         # Score matching.
+         sigmas_t = self.extract(self.sigmas.to(y_noisy.device), t, y_noisy.shape)
+         y = future_time_feat.unsqueeze(1).float()
+         y_noisy1 = output.sample().float().requires_grad_()
+         E = self.score_net(y_noisy1).sum()
+
+         # The loss of multiscale score matching.
+         grad_x = torch.autograd.grad(E, y_noisy1, create_graph=True)[0]
+         loss = torch.mean(torch.sum(((y - y_noisy1.detach()) + grad_x * 0.001) ** 2 * sigmas_t, [1, 2, 3])).float()
+         return output, y_noisy, total_c, all_z, loss
+
+
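+ # The score-matching loss above is, per batch element,
+ #   sum( sigma_t * ((y - y_hat) + 0.001 * dE/dy_hat)**2 )
+ # averaged over the batch, where E is the scalar energy produced by score_net.
+
+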
+ class pred_net(denoise_net):
+     def forward(self, x, mark):
+         """
+         Generate the prediction with the trained model.
+         Returns:
+             y: The noisy generative results.
+             out: Denoised results, i.e. y with the score-matching correction removed.
+             tc: Total correlation, an indicator of the extent of disentangling.
+             all_z: All the latent variables of the BVAE.
+         """
+         input = self.embedding(x, mark)
+         x_t, _ = self.diffusion_gen.rnn(input)
+         input = torch.cat([x_t, input], dim=-1)
+         input = input.unsqueeze(1)
+         logits, tc, all_z = self.diffusion_gen.generative(input)
+         output = self.diffusion_gen.generative.decoder_output(logits)
+         y = output.mu.float().requires_grad_()
+
+         try:
+             E = self.score_net(y).sum()
+             grad_x = torch.autograd.grad(E, y, create_graph=True, allow_unused=True)[0]
+             if grad_x is None:  # allow_unused=True may return None
+                 grad_x = 0
+         except Exception as e:
+             beauty_string(e, '')
+             grad_x = 0
+
+         out = y - grad_x * 0.001
+         return y, out, tc, all_z
+
+
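+ # pred_net reuses the trained denoise_net weights but skips the diffusion: it decodes
+ # the BVAE mean and applies a single score-network gradient step as denoising.
+
+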
+ class Discriminator(nn.Module):
+     def __init__(self, neg_slope=0.2, latent_dim=10, hidden_units=1000, out_units=2):
+         """Discriminator proposed in [1].
+         Parameters
+         ----------
+         neg_slope: float
+             Hyperparameter for the LeakyReLU.
+         latent_dim : int
+             Dimensionality of the latent variables.
+         hidden_units: int
+             Number of hidden units in the MLP.
+         Model Architecture
+         ------------------
+         - 6-layer multi-layer perceptron, each layer with 1000 hidden units
+         - LeakyReLU activations
+         - Outputs 2 logits
+         References:
+             [1] Kim, Hyunjik, and Andriy Mnih. "Disentangling by factorising."
+             arXiv preprint arXiv:1802.05983 (2018).
+         """
+         super(Discriminator, self).__init__()
+
+         # Activation parameters
+         self.neg_slope = neg_slope
+         self.leaky_relu = nn.LeakyReLU(self.neg_slope, True)
+
+         # Layer parameters
+         self.z_dim = latent_dim
+         self.hidden_units = hidden_units
+         # theoretically 1 output with a sigmoid, but that gives bad results => use 2 and a softmax
+
+         # Fully connected layers
+         self.lin1 = nn.Linear(self.z_dim, hidden_units)
+         self.lin2 = nn.Linear(hidden_units, hidden_units)
+         self.lin3 = nn.Linear(hidden_units, hidden_units)
+         self.lin4 = nn.Linear(hidden_units, hidden_units)
+         self.lin5 = nn.Linear(hidden_units, hidden_units)
+         self.lin6 = nn.Linear(hidden_units, out_units)
+         self.softmax = nn.Softmax(dim=-1)  # not applied in forward; logits are returned
+
+     def forward(self, z):
+         # Fully connected layers with LeakyReLU activations
+         z = self.leaky_relu(self.lin1(z))
+         z = self.leaky_relu(self.lin2(z))
+         z = self.leaky_relu(self.lin3(z))
+         z = self.leaky_relu(self.lin4(z))
+         z = self.leaky_relu(self.lin5(z))
+         z = self.lin6(z)
+         return z
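+
+ # A minimal usage sketch (hypothetical shapes, not part of the package):
+ #   D = Discriminator(latent_dim=10)
+ #   logits = D(z.view(z.size(0), -1))  # (batch, 2) logits; apply softmax externally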