dsipts 1.1.12-py3-none-any.whl → 1.1.15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsipts has been flagged as potentially problematic.

@@ -0,0 +1,96 @@
+## Copyright https://github.com/thuml/Time-Series-Library/blob/main/models/TimesNet.py
+## Modified for notation alignment and batch structure
+## Extended with what is inside the itransformer folder
+
+import torch
+import torch.nn as nn
+import numpy as np
+from .timesnet.Layers import TimesBlock
+from ..data_structure.utils import beauty_string
+from .utils import get_scope,get_activation,Embedding_cat_variables
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+
+
+class TimesNet(Base):
+    handle_multivariate = True
+    handle_future_covariates = False
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
+    def __init__(self,
+                 # specific params
+                 e_layers: int,
+                 d_model: int,
+                 top_k: int,
+                 d_ff: int,
+                 num_kernels: int,
+                 **kwargs) -> None:
+
+        super().__init__(**kwargs)
+
+        self.save_hyperparameters(logger=False)
+        self.e_layers = e_layers
+        self.emb_past = Embedding_cat_variables(self.past_steps,self.emb_dim,self.embs_past, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device=self.device)
+        #self.emb_fut = Embedding_cat_variables(self.future_steps,self.emb_dim,self.embs_fut, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device=self.device)
+        emb_past_out_channel = self.emb_past.output_channels
+        #emb_fut_out_channel = self.emb_fut.output_channels
+
+        self.prepare = nn.Linear(emb_past_out_channel+self.past_channels, d_model)
+
+        self.model = nn.ModuleList([TimesBlock(self.past_steps,self.future_steps,top_k,d_model,d_ff,num_kernels) for _ in range(e_layers)])
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.predict_linear = nn.Linear(self.past_steps, self.future_steps + self.past_steps)
+
+        self.projection = nn.Linear(d_model, self.out_channels*self.mul, bias=True)
+
+    def can_be_compiled(self):
+        return False  #True
+
+    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
+        # Normalization from Non-stationary Transformer
+        #means = x_enc.mean(1, keepdim=True).detach()
+        #x_enc = x_enc.sub(means)
+        #stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
+        #x_enc = x_enc.div(stdev)
+
+        # embedding
+        enc_out = torch.cat([x_enc, x_mark_enc],axis=2)  # [B,T,C]
+        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1)  # align temporal dimension
+        # TimesNet
+        enc_out = self.prepare(enc_out)
+        for i in range(self.e_layers):
+            enc_out = self.layer_norm(self.model[i](enc_out))
+        # project back
+        dec_out = self.projection(enc_out)
+
+        # De-Normalization from Non-stationary Transformer
+        #dec_out = dec_out.mul((stdev[:, 0, :].unsqueeze(1).repeat(1, self.future_steps + self.past_steps, 1)))
+        #dec_out = dec_out.add((means[:, 0, :].unsqueeze(1).repeat(1, self.future_steps + self.past_steps, 1)))
+        return dec_out
+
+    def forward(self, batch: dict) -> float:
+
+        x_enc = batch['x_num_past'].to(self.device)
+        BS = x_enc.shape[0]
+        if 'x_cat_past' in batch.keys():
+            emb_past = self.emb_past(BS,batch['x_cat_past'].to(self.device))
+        else:
+            emb_past = self.emb_past(BS,None)
+
+        dec_out = self.forecast(x_enc, emb_past, None, None)
+
+        return dec_out[:, -self.future_steps:,:].reshape(BS,self.future_steps,self.out_channels,self.mul)
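The model above aligns past and future along the time axis with `predict_linear` before the TimesBlocks run. A minimal sketch of that trick, with illustrative sizes that are assumptions rather than package defaults:

```python
import torch
import torch.nn as nn

# Illustrative sizes only; not values taken from the package.
B, past_steps, future_steps, channels = 4, 32, 8, 16

x = torch.randn(B, past_steps, channels)                  # [B, T_past, C]
predict_linear = nn.Linear(past_steps, past_steps + future_steps)

# The linear layer acts on the time axis: permute to [B, C, T_past],
# map to [B, C, T_past + T_future], then permute back.
out = predict_linear(x.permute(0, 2, 1)).permute(0, 2, 1)
print(out.shape)  # torch.Size([4, 40, 16])
```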
dsipts/models/base.py CHANGED
@@ -14,6 +14,7 @@ import matplotlib.pyplot as plt
 from typing import List, Union
 from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
+from torch.optim import Adam, AdamW, SGD, RMSprop
 
 def standardize_momentum(x,order):
     mean = torch.mean(x,1).unsqueeze(1).repeat(1,x.shape[1],1)
@@ -336,7 +337,7 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
-        if len(self._val_outputs)>0:
+        if (len(self._val_outputs)>0) & (self.trainer.max_epochs>0):
             ys = torch.cat([o["y"] for o in self._val_outputs])
             y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
             if self.use_quantiles:
@@ -392,8 +393,15 @@ class Base(pl.LightningModule):
             initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
         else:
             initial_loss = self.loss(y_hat, batch['y'])
+
+        if self.loss_type in ['mse','l1']:
+            return initial_loss
+
         x = batch['x_num_past'].to(self.device)
         idx_target = batch['idx_target'][0]
+        if idx_target is None:
+            beauty_string('Can not compute non-standard loss for non autoregressive models, if you want to use custom losses please add check=True while initializing the time series object','info',self.verbose)
+            return initial_loss
         x_start = x[:,-1,idx_target].unsqueeze(1)
         y_persistence = x_start.repeat(1,self.future_steps,1)
 
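Both guards added here fall back to `initial_loss` when the custom persistence-based losses cannot be computed. A minimal sketch of the persistence baseline those losses compare against (shapes and target indices are illustrative assumptions):

```python
import torch

B, past_steps, future_steps = 2, 16, 4
x = torch.randn(B, past_steps, 3)   # numeric past window, 3 channels
idx_target = [0, 2]                 # hypothetical target-channel indices

# Persistence baseline: repeat the last observed target values over the horizon.
x_start = x[:, -1, idx_target].unsqueeze(1)          # [B, 1, n_targets]
y_persistence = x_start.repeat(1, future_steps, 1)   # [B, future_steps, n_targets]
print(y_persistence.shape)  # torch.Size([2, 4, 2])
```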
dsipts/models/base_v2.py CHANGED
@@ -14,6 +14,7 @@ import matplotlib.pyplot as plt
 from typing import List, Union
 from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
+from torch.optim import Adam, AdamW, SGD, RMSprop
 
 def standardize_momentum(x,order):
     mean = torch.mean(x,1).unsqueeze(1).repeat(1,x.shape[1],1)
@@ -220,7 +221,7 @@ class Base(pl.LightningModule):
 
 
         if self.optim is None:
-            optimizer = optim.Adam(self.parameters(), **self.optim_config)
+            optimizer = Adam(self.parameters(), **self.optim_config)
             self.initialize = True
 
         else:
@@ -237,7 +238,7 @@ class Base(pl.LightningModule):
 
             beauty_string(self.optim,'',self.verbose)
             if self.has_sam_optim:
-                optimizer = SAM(self.parameters(), base_optimizer=torch.optim.Adam, **self.optim_config)
+                optimizer = SAM(self.parameters(), base_optimizer=Adam, **self.optim_config)
             else:
                 optimizer = self.optim(self.parameters(), **self.optim_config)
             beauty_string(optimizer,'',self.verbose)
@@ -314,7 +315,7 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
-
+
         if self.return_additional_loss:
             y_hat,score = self(batch)
         else:
@@ -335,7 +336,6 @@ class Base(pl.LightningModule):
     def on_validation_start(self):
         # reset buffer each epoch
         self._val_outputs = []
-
 
     def on_validation_epoch_end(self):
         """
@@ -344,7 +344,7 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
-        if len(self._val_outputs)>0:
+        if (len(self._val_outputs)>0) & (self.trainer.max_epochs>0):
             ys = torch.cat([o["y"] for o in self._val_outputs])
             y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
             if self.use_quantiles:
@@ -353,6 +353,7 @@ class Base(pl.LightningModule):
                 idx = 0
             for i in range(ys.shape[2]):
                 real = ys[0,:,i].cpu().detach().numpy()
+
                 pred = y_hats[0,:,i,idx].cpu().detach().numpy()
                 fig, ax = plt.subplots(figsize=(7,5))
                 ax.plot(real,'o-',label='real')
@@ -363,7 +364,7 @@ class Base(pl.LightningModule):
                 #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
                 plt.close(fig)
 
-
+
         avg = self.validation_epoch_metrics/self.validation_epoch_count
 
         self.validation_epoch_metrics.zero_()
@@ -407,8 +408,17 @@ class Base(pl.LightningModule):
             initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
         else:
             initial_loss = self.loss(y_hat, batch['y'])
+
+        if self.loss_type in ['mse','l1']:
+            return initial_loss
+
         x = batch['x_num_past'].to(self.device)
+
         idx_target = batch['idx_target'][0]
+
+        if idx_target is None:
+            beauty_string('Can not compute non-standard loss for non autoregressive models, if you want to use custom losses please add check=True while initializing the time series object','info',self.verbose)
+            return initial_loss
         x_start = x[:,-1,idx_target].unsqueeze(1)
         y_persistence = x_start.repeat(1,self.future_steps,1)
 
@@ -423,7 +433,7 @@ class Base(pl.LightningModule):
             persistence_error = (2.0-10.0*torch.clamp( torch.abs((y_persistence-x)/(0.001+torch.abs(y_persistence))),min=0.0,max=max(0.05,0.1*(1+np.log10(self.persistence_weight) ))))
             loss = torch.mean(torch.abs(x- batch['y'])*persistence_error)
 
-        if self.loss_type == 'mda':
+        elif self.loss_type == 'mda':
             #import pdb
             #pdb.set_trace()
             mda = (1-torch.mean( torch.sign(torch.diff(x,axis=1))*torch.sign(torch.diff(batch['y'],axis=1))))
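The `if` → `elif` change makes the `'dilated'` and `'mda'` branches mutually exclusive, so a dilated loss is no longer overwritten by the mda term. A sketch of what the mda (mean directional accuracy) penalty computes, on random illustrative tensors:

```python
import torch

x = torch.randn(2, 8, 1)   # prediction, [B, T, C]
y = torch.randn(2, 8, 1)   # target,     [B, T, C]

# 1 - mean agreement between the signs of consecutive differences:
# 0 when every predicted step moves in the true direction, up to 2 otherwise.
mda = 1 - torch.mean(torch.sign(torch.diff(x, dim=1)) * torch.sign(torch.diff(y, dim=1)))
print(float(mda))
```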
@@ -0,0 +1,284 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.fft
+
+
+class Normalize(nn.Module):
+    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
+        """
+        :param num_features: the number of features or channels
+        :param eps: a value added for numerical stability
+        :param affine: if True, RevIN has learnable affine parameters
+        """
+        super(Normalize, self).__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
+        self.subtract_last = subtract_last
+        self.non_norm = non_norm
+        if self.affine:
+            self._init_params()
+
+    def forward(self, x, mode: str):
+        if mode == 'norm':
+            self._get_statistics(x)
+            x = self._normalize(x)
+        elif mode == 'denorm':
+            x = self._denormalize(x)
+        else:
+            raise NotImplementedError
+        return x
+
+    def _init_params(self):
+        # initialize RevIN params: (C,)
+        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
+        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
+
+    def _get_statistics(self, x):
+        dim2reduce = tuple(range(1, x.ndim - 1))
+        if self.subtract_last:
+            self.last = x[:, -1, :].unsqueeze(1)
+        else:
+            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
+        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
+
+    def _normalize(self, x):
+        if self.non_norm:
+            return x
+        if self.subtract_last:
+            x = x - self.last
+        else:
+            x = x - self.mean
+        x = x / self.stdev
+        if self.affine:
+            x = x * self.affine_weight
+            x = x + self.affine_bias
+        return x
+
+    def _denormalize(self, x):
+        if self.non_norm:
+            return x
+        if self.affine:
+            x = x - self.affine_bias
+            x = x / (self.affine_weight + self.eps * self.eps)
+        x = x * self.stdev
+        if self.subtract_last:
+            x = x + self.last
+        else:
+            x = x + self.mean
+        return x
+
+
+class ChebyKANLinear(nn.Module):
+    def __init__(self, input_dim, output_dim, degree):
+        super(ChebyKANLinear, self).__init__()
+        self.inputdim = input_dim
+        self.outdim = output_dim
+        self.degree = degree
+
+        self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
+        self.epsilon = 1e-7
+        self.pre_mul = False
+        self.post_mul = False
+        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))
+        self.register_buffer("arange", torch.arange(0, degree + 1, 1))
+
+    def forward(self, x):
+        # Chebyshev polynomials are defined on [-1, 1],
+        # so normalize x to [-1, 1] using tanh.
+        # View and repeat the input degree + 1 times.
+        b, c_in = x.shape
+        if self.pre_mul:
+            mul_1 = x[:, ::2]
+            mul_2 = x[:, 1::2]
+            mul_res = mul_1 * mul_2
+            x = torch.concat([x[:, :x.shape[1] // 2], mul_res])
+        x = x.view((b, c_in, 1)).expand(
+            -1, -1, self.degree + 1
+        )  # shape = (batch_size, inputdim, self.degree + 1)
+        x = torch.tanh(x)
+        x = torch.tanh(x)
+        # Apply acos
+        x = torch.acos(x)
+        # x = torch.acos(torch.clamp(x, -1 + self.epsilon, 1 - self.epsilon))
+        # Multiply by arange [0 .. degree]
+        x = x * self.arange
+        # Apply cos
+        x = x.cos()
+        # Compute the Chebyshev interpolation
+        y = torch.einsum(
+            "bid,iod->bo", x, self.cheby_coeffs
+        )  # shape = (batch_size, outdim)
+        y = y.view(-1, self.outdim)
+        if self.post_mul:
+            mul_1 = y[:, ::2]
+            mul_2 = y[:, 1::2]
+            mul_res = mul_1 * mul_2
+            y = torch.concat([y[:, :y.shape[1] // 2], mul_res])
+        return y
+
+
+class ChebyKANLayer(nn.Module):
+    def __init__(self, in_features, out_features, order):
+        super().__init__()
+        self.fc1 = ChebyKANLinear(
+            in_features,
+            out_features,
+            order)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        x = self.fc1(x.reshape(B * N, C))
+        x = x.reshape(B, N, -1).contiguous()
+        return x
+
+
+class FrequencyDecomp(nn.Module):
+
+    def __init__(self, seq_len, down_sampling_window, down_sampling_layers):
+        super(FrequencyDecomp, self).__init__()
+        self.seq_len = seq_len
+        self.down_sampling_window = down_sampling_window
+        self.down_sampling_layers = down_sampling_layers
+
+    def forward(self, level_list):
+        level_list_reverse = level_list.copy()
+        level_list_reverse.reverse()
+        out_low = level_list_reverse[0]
+        out_high = level_list_reverse[1]
+        out_level_list = [out_low]
+        for i in range(len(level_list_reverse) - 1):
+            out_high_res = self.frequency_interpolation(out_low.transpose(1, 2),
+                                                        self.seq_len // (self.down_sampling_window ** (self.down_sampling_layers - i)),
+                                                        self.seq_len // (self.down_sampling_window ** (self.down_sampling_layers - i - 1))
+                                                        ).transpose(1, 2)
+            out_high_left = out_high - out_high_res
+            out_low = out_high
+            if i + 2 <= len(level_list_reverse) - 1:
+                out_high = level_list_reverse[i + 2]
+            out_level_list.append(out_high_left)
+        out_level_list.reverse()
+        return out_level_list
+
+    def frequency_interpolation(self, x, seq_len, target_len):
+        len_ratio = seq_len / target_len
+        x_fft = torch.fft.rfft(x, dim=2)
+        out_fft = torch.zeros([x_fft.size(0), x_fft.size(1), target_len // 2 + 1], dtype=x_fft.dtype).to(x_fft.device)
+        out_fft[:, :, :seq_len // 2 + 1] = x_fft
+        out = torch.fft.irfft(out_fft, dim=2)
+        out = out * len_ratio
+        return out
+
+
+class FrequencyMixing(nn.Module):
+
+    def __init__(self, d_model, seq_len, begin_order, down_sampling_window, down_sampling_layers):
+        super(FrequencyMixing, self).__init__()
+        self.front_block = M_KAN(d_model, seq_len // (down_sampling_window ** (down_sampling_layers)), order=begin_order)
+        self.d_model = d_model
+        self.seq_len = seq_len
+        self.begin_order = begin_order
+        self.down_sampling_window = down_sampling_window
+        self.down_sampling_layers = down_sampling_layers
+
+        self.front_blocks = torch.nn.ModuleList(
+            [
+                M_KAN(d_model,
+                      seq_len // (down_sampling_window ** (down_sampling_layers - i - 1)),
+                      order=i + begin_order + 1)
+                for i in range(down_sampling_layers)
+            ])
+
+    def forward(self, level_list):
+        level_list_reverse = level_list.copy()
+        level_list_reverse.reverse()
+        out_low = level_list_reverse[0]
+        out_high = level_list_reverse[1]
+        out_low = self.front_block(out_low)
+        out_level_list = [out_low]
+        for i in range(len(level_list_reverse) - 1):
+            out_high = self.front_blocks[i](out_high)
+            out_high_res = self.frequency_interpolation(out_low.transpose(1, 2),
+                                                        self.seq_len // (self.down_sampling_window ** (self.down_sampling_layers - i)),
+                                                        self.seq_len // (self.down_sampling_window ** (self.down_sampling_layers - i - 1))
+                                                        ).transpose(1, 2)
+            out_high = out_high + out_high_res
+            out_low = out_high
+            if i + 2 <= len(level_list_reverse) - 1:
+                out_high = level_list_reverse[i + 2]
+            out_level_list.append(out_low)
+        out_level_list.reverse()
+        return out_level_list
+
+    def frequency_interpolation(self, x, seq_len, target_len):
+        len_ratio = seq_len / target_len
+        x_fft = torch.fft.rfft(x, dim=2)
+        out_fft = torch.zeros([x_fft.size(0), x_fft.size(1), target_len // 2 + 1], dtype=x_fft.dtype).to(x_fft.device)
+        out_fft[:, :, :seq_len // 2 + 1] = x_fft
+        out = torch.fft.irfft(out_fft, dim=2)
+        out = out * len_ratio
+        return out
+
+
+class M_KAN(nn.Module):
+    def __init__(self, d_model, seq_len, order):
+        super().__init__()
+        self.channel_mixer = nn.Sequential(
+            ChebyKANLayer(d_model, d_model, order)
+        )
+        self.conv = BasicConv(d_model, d_model, kernel_size=3, degree=order, groups=d_model)
+
+    def forward(self, x):
+        x1 = self.channel_mixer(x)
+        x2 = self.conv(x)
+        out = x1 + x2
+        return out
+
+
+class BasicConv(nn.Module):
+    def __init__(self, c_in, c_out, kernel_size, degree, stride=1, padding=0, dilation=1, groups=1, act=False, bn=False, bias=False, dropout=0.):
+        super(BasicConv, self).__init__()
+        self.out_channels = c_out
+        self.conv = nn.Conv1d(c_in, c_out, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, dilation=dilation, groups=groups, bias=bias)
+        self.bn = nn.BatchNorm1d(c_out) if bn else None
+        self.act = nn.GELU() if act else None
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        if self.bn is not None:
+            x = self.bn(x)
+        x = self.conv(x.transpose(-1, -2)).transpose(-1, -2)
+        if self.act is not None:
+            x = self.act(x)
+        if self.dropout is not None:
+            x = self.dropout(x)
+        return x
+
+
+class moving_avg(nn.Module):
+    """
+    Moving average block to highlight the trend of time series
+    """
+
+    def __init__(self, kernel_size, stride):
+        super(moving_avg, self).__init__()
+        self.kernel_size = kernel_size
+        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
+
+    def forward(self, x):
+        # padding on both ends of the time series
+        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
+        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
+        x = torch.cat([front, x, end], dim=1)
+        x = self.avg(x.permute(0, 2, 1))
+        x = x.permute(0, 2, 1)
+        return x
+
+
+class series_decomp(nn.Module):
+    """
+    Series decomposition block
+    """
+
+    def __init__(self, kernel_size):
+        super(series_decomp, self).__init__()
+        self.moving_avg = moving_avg(kernel_size, stride=1)
+
+    def forward(self, x):
+        moving_mean = self.moving_avg(x)
+        res = x - moving_mean
+        return res, moving_mean
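A short usage sketch of two of the blocks above, assuming the file is importable; shapes are illustrative. `Normalize` implements a RevIN-style round trip, and `series_decomp` splits a series into residual and trend:

```python
import torch

x = torch.randn(4, 96, 7)                    # [B, T, C]

# RevIN-style round trip: 'norm' stores the statistics that 'denorm' reuses.
norm = Normalize(num_features=7)
x_norm = norm(x, 'norm')
x_back = norm(x_norm, 'denorm')
print(torch.allclose(x, x_back, atol=1e-4))  # True (up to eps)

# Trend/residual split with a centred moving average.
decomp = series_decomp(kernel_size=25)
res, trend = decomp(x)
print(res.shape, trend.shape)                # both torch.Size([4, 96, 7])
```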
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.fft
+
+
+class Inception_Block_V1(nn.Module):
+    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
+        super(Inception_Block_V1, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_kernels = num_kernels
+        kernels = []
+        for i in range(self.num_kernels):
+            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i))
+        self.kernels = nn.ModuleList(kernels)
+        if init_weight:
+            self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        res_list = []
+        for i in range(self.num_kernels):
+            res_list.append(self.kernels[i](x))
+        res = torch.stack(res_list, dim=-1).mean(-1)
+        return res
+
+
+def FFT_for_Period(x, k=2):
+    # [B, T, C]
+    xf = torch.fft.rfft(x, dim=1)
+    # find periods by amplitude
+    frequency_list = abs(xf).mean(0).mean(-1)
+    frequency_list[0] = 0
+    _, top_list = torch.topk(frequency_list, k)
+    top_list = top_list.detach().cpu().numpy()
+    period = x.shape[1] // top_list
+    return period, abs(xf).mean(-1)[:, top_list]
+
+
+class TimesBlock(nn.Module):
+    def __init__(self, seq_len, pred_len, top_k, d_model, d_ff, num_kernels):
+        super(TimesBlock, self).__init__()
+        self.seq_len = seq_len
+        self.pred_len = pred_len
+        self.k = top_k
+        # parameter-efficient design
+        self.conv = nn.Sequential(
+            Inception_Block_V1(d_model, d_ff,
+                               num_kernels=num_kernels),
+            nn.GELU(),
+            Inception_Block_V1(d_ff, d_model,
+                               num_kernels=num_kernels)
+        )
+
+    def forward(self, x):
+        B, T, N = x.size()
+
+        period_list, period_weight = FFT_for_Period(x, self.k)
+
+        res = []
+        for i in range(self.k):
+            period = period_list[i]
+            # padding
+            if (self.seq_len + self.pred_len) % period != 0:
+                length = (((self.seq_len + self.pred_len) // period) + 1) * period
+                padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
+                out = torch.cat([x, padding], dim=1)
+            else:
+                length = (self.seq_len + self.pred_len)
+                out = x
+            # reshape
+            out = out.reshape(B, length // period, period,
+                              N).permute(0, 3, 1, 2).contiguous()
+            # 2D conv: from 1d Variation to 2d Variation
+            out = self.conv(out)
+            # reshape back
+            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
+            res.append(out[:, :(self.seq_len + self.pred_len), :])
+        res = torch.stack(res, dim=-1)
+        # adaptive aggregation
+        period_weight = F.softmax(period_weight, dim=1)
+        period_weight = period_weight.unsqueeze(
+            1).unsqueeze(1).repeat(1, T, N, 1)
+        res = torch.sum(res * period_weight, -1)
+        # residual connection
+        res = res + x
+        return res
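`FFT_for_Period` drives the reshaping in `TimesBlock`: it picks the top-k amplitude frequencies and converts them to candidate periods. A sketch of its behaviour on a pure 24-step sine, assuming the function above is importable (sizes illustrative):

```python
import torch

t = torch.arange(96, dtype=torch.float32)
x = torch.sin(2 * torch.pi * t / 24).reshape(1, 96, 1)   # [B, T, C]

period, weight = FFT_for_Period(x, k=2)
print(period)        # numpy array; the dominant entry is 24 (= 96 // 4)
print(weight.shape)  # torch.Size([1, 2]): amplitudes of the top-k frequencies
```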
dsipts/version.py ADDED
@@ -0,0 +1 @@
+__version__ = "1.1.15"