dsipts 1.1.12__py3-none-any.whl → 1.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dsipts might be problematic.
- dsipts/__init__.py +4 -1
- dsipts/data_structure/data_structure.py +110 -32
- dsipts/data_structure/utils.py +4 -2
- dsipts/models/Persistent.py +2 -0
- dsipts/models/TTM.py +31 -9
- dsipts/models/TimeKAN.py +123 -0
- dsipts/models/TimesNet.py +96 -0
- dsipts/models/base.py +9 -1
- dsipts/models/base_v2.py +17 -7
- dsipts/models/timekan/Layers.py +284 -0
- dsipts/models/timekan/__init__.py +0 -0
- dsipts/models/timesnet/Layers.py +95 -0
- dsipts/models/timesnet/__init__.py +0 -0
- dsipts/version.py +1 -0
- {dsipts-1.1.12.dist-info → dsipts-1.1.15.dist-info}/METADATA +56 -8
- {dsipts-1.1.12.dist-info → dsipts-1.1.15.dist-info}/RECORD +18 -11
- {dsipts-1.1.12.dist-info → dsipts-1.1.15.dist-info}/WHEEL +0 -0
- {dsipts-1.1.12.dist-info → dsipts-1.1.15.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,96 @@
+## Copyright https://github.com/thuml/Time-Series-Library/blob/main/models/TimesNet.py
+## Modified for notation alignmenet and batch structure
+## extended to what inside itransformer folder
+
+import torch
+import torch.nn as nn
+import numpy as np
+from .timesnet.Layers import TimesBlock
+from ..data_structure.utils import beauty_string
+from .utils import get_scope,get_activation,Embedding_cat_variables
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+
+
+
+class TimesNet(Base):
+    handle_multivariate = True
+    handle_future_covariates = False
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
+    def __init__(self,
+                 # specific params
+                 e_layers:int,
+                 d_model: int,
+                 top_k: int,
+                 d_ff: int,
+                 num_kernels: int,
+                 **kwargs)->None:
+
+
+
+
+        super().__init__(**kwargs)
+
+        self.save_hyperparameters(logger=False)
+        self.e_layers = e_layers
+        self.emb_past = Embedding_cat_variables(self.past_steps,self.emb_dim,self.embs_past, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
+        #self.emb_fut = Embedding_cat_variables(self.future_steps,self.emb_dim,self.embs_fut, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
+        emb_past_out_channel = self.emb_past.output_channels
+        #emb_fut_out_channel = self.emb_fut.output_channels
+
+        self.prepare = nn.Linear(emb_past_out_channel+self.past_channels, d_model)
+
+        self.model = nn.ModuleList([TimesBlock(self.past_steps,self.future_steps,top_k,d_model,d_ff,num_kernels) for _ in range(e_layers)])
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.predict_linear = nn.Linear(self.past_steps, self.future_steps + self.past_steps)
+
+        self.projection = nn.Linear(d_model, self.out_channels*self.mul, bias=True)
+
+    def can_be_compiled(self):
+        return False#True
+
+    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
+        # Normalization from Non-stationary Transformer
+        #means = x_enc.mean(1, keepdim=True).detach()
+        #x_enc = x_enc.sub(means)
+        #stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
+        #x_enc = x_enc.div(stdev)
+
+        # embedding
+        enc_out = torch.cat([x_enc, x_mark_enc],axis=2) # [B,T,C]
+        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1) # align temporal dimension
+        # TimesNet
+        enc_out = self.prepare(enc_out)
+        for i in range(self.e_layers):
+            enc_out = self.layer_norm(self.model[i](enc_out))
+        # project back
+        dec_out = self.projection(enc_out)
+
+        # De-Normalization from Non-stationary Transformer
+        #dec_out = dec_out.mul((stdev[:, 0, :].unsqueeze(1).repeat(1, self.future_steps + self.past_steps, 1)))
+        #dec_out = dec_out.add((means[:, 0, :].unsqueeze(1).repeat(1, self.future_steps + self.past_steps, 1)))
+        return dec_out
+
+    def forward(self, batch:dict)-> float:
+
+        x_enc = batch['x_num_past'].to(self.device)
+        BS = x_enc.shape[0]
+        if 'x_cat_past' in batch.keys():
+            emb_past = self.emb_past(BS,batch['x_cat_past'].to(self.device))
+        else:
+            emb_past = self.emb_past(BS,None)
+
+        dec_out = self.forecast(x_enc, emb_past, None, None)
+
+        return dec_out[:, -self.future_steps:,:].reshape(BS,self.future_steps,self.out_channels,self.mul)
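For orientation, the heart of the added forecast method is a temporal-alignment trick: a linear layer stretches the past window from past_steps to past_steps + future_steps steps before the TimesBlock stack, and the final projection is reshaped to (batch, future_steps, out_channels, mul). A minimal standalone sketch of that shape flow (plain PyTorch, not the dsipts API; all sizes below are illustrative):

import torch
import torch.nn as nn

B, past_steps, future_steps = 8, 64, 16
in_channels, d_model, out_channels, mul = 7, 32, 3, 1       # illustrative sizes

enc_out = torch.randn(B, past_steps, in_channels)           # past window plus embeddings, [B, T, C]
predict_linear = nn.Linear(past_steps, past_steps + future_steps)
enc_out = predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1)   # stretch the time axis: [B, T+H, C]
prepare = nn.Linear(in_channels, d_model)
enc_out = prepare(enc_out)                                  # [B, T+H, d_model], what the TimesBlock stack consumes
projection = nn.Linear(d_model, out_channels * mul)
dec_out = projection(enc_out)                               # [B, T+H, out_channels*mul]
y = dec_out[:, -future_steps:, :].reshape(B, future_steps, out_channels, mul)
print(y.shape)                                              # torch.Size([8, 16, 3, 1])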
dsipts/models/base.py
CHANGED
@@ -14,6 +14,7 @@ import matplotlib.pyplot as plt
 from typing import List, Union
 from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
+from torch.optim import Adam, AdamW, SGD, RMSprop
 
 def standardize_momentum(x,order):
     mean = torch.mean(x,1).unsqueeze(1).repeat(1,x.shape[1],1)
@@ -336,7 +337,7 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
-        if len(self._val_outputs)>0:
+        if (len(self._val_outputs)>0) & (self.trainer.max_epochs>0):
             ys = torch.cat([o["y"] for o in self._val_outputs])
             y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
             if self.use_quantiles:
@@ -392,8 +393,15 @@
             initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
         else:
             initial_loss = self.loss(y_hat, batch['y'])
+
+        if self.loss_type in ['mse','l1']:
+            return initial_loss
+
         x = batch['x_num_past'].to(self.device)
         idx_target = batch['idx_target'][0]
+        if idx_target is None:
+            beauty_string(f'Can not compute non-standard loss for non autoregressive models, if you want to use custom losses please add check=True wile initialize the time series object','info',self.verbose)
+            return initial_loss
         x_start = x[:,-1,idx_target].unsqueeze(1)
         y_persistence = x_start.repeat(1,self.future_steps,1)
 
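The new early return for loss_type in ['mse','l1'] means the persistence-based penalty below it is only assembled for custom losses, and the idx_target guard skips it for non-autoregressive setups. For reference, the persistence baseline being built is just the last observed target repeated over the horizon; a toy standalone illustration (made-up tensors, not dsipts internals):

import torch

future_steps = 4
x_num_past = torch.tensor([[[1.0], [2.0], [3.0]]])       # [B=1, past_steps=3, channels=1]
idx_target = [0]                                         # position of the target column

x_start = x_num_past[:, -1, idx_target].unsqueeze(1)     # last observed target value, shape [1, 1, 1]
y_persistence = x_start.repeat(1, future_steps, 1)       # [1, 4, 1], every step equals 3.0
print(y_persistence.squeeze())                           # tensor([3., 3., 3., 3.])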
dsipts/models/base_v2.py
CHANGED
@@ -14,6 +14,7 @@ import matplotlib.pyplot as plt
 from typing import List, Union
 from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
+from torch.optim import Adam, AdamW, SGD, RMSprop
 
 def standardize_momentum(x,order):
     mean = torch.mean(x,1).unsqueeze(1).repeat(1,x.shape[1],1)
@@ -220,7 +221,7 @@
 
 
         if self.optim is None:
-            optimizer =
+            optimizer = Adam(self.parameters(), **self.optim_config)
             self.initialize = True
 
         else:
@@ -237,7 +238,7 @@
 
             beauty_string(self.optim,'',self.verbose)
             if self.has_sam_optim:
-                optimizer = SAM(self.parameters(), base_optimizer=
+                optimizer = SAM(self.parameters(), base_optimizer=Adam, **self.optim_config)
             else:
                 optimizer = self.optim(self.parameters(), **self.optim_config)
         beauty_string(optimizer,'',self.verbose)
@@ -314,7 +315,7 @@
 
         :meta private:
         """
-
+
         if self.return_additional_loss:
             y_hat,score = self(batch)
         else:
@@ -335,7 +336,6 @@
     def on_validation_start(self):
         # reset buffer each epoch
         self._val_outputs = []
-
 
     def on_validation_epoch_end(self):
         """
@@ -344,7 +344,7 @@
         :meta private:
         """
 
-        if len(self._val_outputs)>0:
+        if (len(self._val_outputs)>0) & (self.trainer.max_epochs>0):
             ys = torch.cat([o["y"] for o in self._val_outputs])
             y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
             if self.use_quantiles:
@@ -353,6 +353,7 @@
                 idx = 0
             for i in range(ys.shape[2]):
                 real = ys[0,:,i].cpu().detach().numpy()
+
                 pred = y_hats[0,:,i,idx].cpu().detach().numpy()
                 fig, ax = plt.subplots(figsize=(7,5))
                 ax.plot(real,'o-',label='real')
@@ -363,7 +364,7 @@
                 #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
                 plt.close(fig)
 
-
+
             avg = self.validation_epoch_metrics/self.validation_epoch_count
 
             self.validation_epoch_metrics.zero_()
@@ -407,8 +408,17 @@
             initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
         else:
             initial_loss = self.loss(y_hat, batch['y'])
+
+        if self.loss_type in ['mse','l1']:
+            return initial_loss
+
         x = batch['x_num_past'].to(self.device)
+
         idx_target = batch['idx_target'][0]
+
+        if idx_target is None:
+            beauty_string(f'Can not compute non-standard loss for non autoregressive models, if you want to use custom losses please add check=True wile initialize the time series object','info',self.verbose)
+            return initial_loss
         x_start = x[:,-1,idx_target].unsqueeze(1)
         y_persistence = x_start.repeat(1,self.future_steps,1)
 
@@ -423,7 +433,7 @@
             persistence_error = (2.0-10.0*torch.clamp( torch.abs((y_persistence-x)/(0.001+torch.abs(y_persistence))),min=0.0,max=max(0.05,0.1*(1+np.log10(self.persistence_weight) ))))
             loss = torch.mean(torch.abs(x- batch['y'])*persistence_error)
 
-
+        elif self.loss_type == 'mda':
             #import pdb
             #pdb.set_trace()
             mda = (1-torch.mean( torch.sign(torch.diff(x,axis=1))*torch.sign(torch.diff(batch['y'],axis=1))))
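The mda branch above scores directional agreement: it compares the sign of step-to-step changes in the prediction against the target. A standalone toy check of that term (hypothetical values, a single series; not dsipts code):

import torch

y_true = torch.tensor([[1.0, 2.0, 1.5, 2.5]])    # target moves up, down, up
y_pred = torch.tensor([[1.1, 1.9, 1.7, 2.4]])    # prediction moves the same way

agreement = torch.sign(torch.diff(y_pred, dim=1)) * torch.sign(torch.diff(y_true, dim=1))
mda_term = 1 - torch.mean(agreement)             # 0 when every direction matches, 2 when none do
print(mda_term)                                  # tensor(0.)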
dsipts/models/timekan/Layers.py
ADDED
@@ -0,0 +1,284 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.fft
+
+
+class Normalize(nn.Module):
+    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
+        """
+        :param num_features: the number of features or channels
+        :param eps: a value added for numerical stability
+        :param affine: if True, RevIN has learnable affine parameters
+        """
+        super(Normalize, self).__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
+        self.subtract_last = subtract_last
+        self.non_norm = non_norm
+        if self.affine:
+            self._init_params()
+
+    def forward(self, x, mode: str):
+        if mode == 'norm':
+            self._get_statistics(x)
+            x = self._normalize(x)
+        elif mode == 'denorm':
+            x = self._denormalize(x)
+        else:
+            raise NotImplementedError
+        return x
+
+    def _init_params(self):
+        # initialize RevIN params: (C,)
+        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
+        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
+
+    def _get_statistics(self, x):
+        dim2reduce = tuple(range(1, x.ndim - 1))
+        if self.subtract_last:
+            self.last = x[:, -1, :].unsqueeze(1)
+        else:
+            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
+        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
+
+    def _normalize(self, x):
+        if self.non_norm:
+            return x
+        if self.subtract_last:
+            x = x - self.last
+        else:
+            x = x - self.mean
+        x = x / self.stdev
+        if self.affine:
+            x = x * self.affine_weight
+            x = x + self.affine_bias
+        return x
+
+    def _denormalize(self, x):
+        if self.non_norm:
+            return x
+        if self.affine:
+            x = x - self.affine_bias
+            x = x / (self.affine_weight + self.eps * self.eps)
+        x = x * self.stdev
+        if self.subtract_last:
+            x = x + self.last
+        else:
+            x = x + self.mean
+        return x
+
+class ChebyKANLinear(nn.Module):
+    def __init__(self, input_dim, output_dim, degree):
+        super(ChebyKANLinear, self).__init__()
+        self.inputdim = input_dim
+        self.outdim = output_dim
+        self.degree = degree
+
+        self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
+        self.epsilon = 1e-7
+        self.pre_mul = False
+        self.post_mul = False
+        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))
+        self.register_buffer("arange", torch.arange(0, degree + 1, 1))
+
+    def forward(self, x):
+        # Since Chebyshev polynomial is defined in [-1, 1]
+        # We need to normalize x to [-1, 1] using tanh
+        # View and repeat input degree + 1 times
+        b,c_in = x.shape
+        if self.pre_mul:
+            mul_1 = x[:,::2]
+            mul_2 = x[:,1::2]
+            mul_res = mul_1 * mul_2
+            x = torch.concat([x[:,:x.shape[1]//2], mul_res])
+        x = x.view((b, c_in, 1)).expand(
+            -1, -1, self.degree + 1
+        )  # shape = (batch_size, inputdim, self.degree + 1)
+        # Apply acos
+        x = torch.tanh(x)
+        x = torch.tanh(x)
+        x = torch.acos(x)
+        # x = torch.acos(torch.clamp(x, -1 + self.epsilon, 1 - self.epsilon))
+        # # Multiply by arange [0 .. degree]
+        x = x* self.arange
+        # Apply cos
+        x = x.cos()
+        # Compute the Chebyshev interpolation
+        y = torch.einsum(
+            "bid,iod->bo", x, self.cheby_coeffs
+        )  # shape = (batch_size, outdim)
+        y = y.view(-1, self.outdim)
+        if self.post_mul:
+            mul_1 = y[:,::2]
+            mul_2 = y[:,1::2]
+            mul_res = mul_1 * mul_2
+            y = torch.concat([y[:,:y.shape[1]//2], mul_res])
+        return y
+
+class ChebyKANLayer(nn.Module):
+    def __init__(self, in_features, out_features,order):
+        super().__init__()
+        self.fc1 = ChebyKANLinear(
+            in_features,
+            out_features,
+            order)
+    def forward(self, x):
+        B, N, C = x.shape
+        x = self.fc1(x.reshape(B*N,C))
+        x = x.reshape(B,N,-1).contiguous()
+        return x
+
+
+class FrequencyDecomp(nn.Module):
+
+    def __init__(self, seq_len,down_sampling_window,down_sampling_layers):
+        super(FrequencyDecomp, self).__init__()
+        self.seq_len = seq_len
+        self.down_sampling_window = down_sampling_window
+        self.down_sampling_layers = down_sampling_layers
+
+    def forward(self, level_list):
+
+        level_list_reverse = level_list.copy()
+        level_list_reverse.reverse()
+        out_low = level_list_reverse[0]
+        out_high = level_list_reverse[1]
+        out_level_list = [out_low]
+        for i in range(len(level_list_reverse) - 1):
+            out_high_res = self.frequency_interpolation(out_low.transpose(1,2),
+                self.seq_len // (self.down_sampling_window ** (self.down_sampling_layers-i)),
+                self.seq_len // (self.down_sampling_window ** (self.down_sampling_layers-i-1))
+                ).transpose(1,2)
+            out_high_left = out_high - out_high_res
+            out_low = out_high
+            if i + 2 <= len(level_list_reverse) - 1:
+                out_high = level_list_reverse[i + 2]
+            out_level_list.append(out_high_left)
+        out_level_list.reverse()
+        return out_level_list
+
+    def frequency_interpolation(self,x,seq_len,target_len):
+        len_ratio = seq_len/target_len
+        x_fft = torch.fft.rfft(x, dim=2)
+        out_fft = torch.zeros([x_fft.size(0),x_fft.size(1),target_len//2+1],dtype=x_fft.dtype).to(x_fft.device)
+        out_fft[:,:,:seq_len//2+1] = x_fft
+        out = torch.fft.irfft(out_fft, dim=2)
+        out = out * len_ratio
+        return out
+
+
+class FrequencyMixing(nn.Module):
+
+    def __init__(self, d_model,seq_len,begin_order,down_sampling_window,down_sampling_layers):
+        super(FrequencyMixing, self).__init__()
+        self.front_block = M_KAN(d_model,seq_len // (down_sampling_window ** (down_sampling_layers)),order=begin_order)
+        self.d_model = d_model
+        self.seq_len = seq_len
+        self.begin_order = begin_order
+        self.down_sampling_window = down_sampling_window
+        self.down_sampling_layers = down_sampling_layers
+
+        self.front_blocks = torch.nn.ModuleList(
+            [
+                M_KAN(d_model,
+                      seq_len // (down_sampling_window ** (down_sampling_layers-i-1)),
+                      order=i+begin_order+1)
+                for i in range(down_sampling_layers)
+            ])
+
+    def forward(self, level_list):
+        level_list_reverse = level_list.copy()
+        level_list_reverse.reverse()
+        out_low = level_list_reverse[0]
+        out_high = level_list_reverse[1]
+        out_low = self.front_block(out_low)
+        out_level_list = [out_low]
+        for i in range(len(level_list_reverse) - 1):
+            out_high = self.front_blocks[i](out_high)
+            out_high_res = self.frequency_interpolation(out_low.transpose(1,2),
+                self.seq_len // (self.down_sampling_window ** (self.down_sampling_layers-i)),
+                self.seq_len // (self.down_sampling_window ** (self.down_sampling_layers-i-1))
+                ).transpose(1,2)
+            out_high = out_high + out_high_res
+            out_low = out_high
+            if i + 2 <= len(level_list_reverse) - 1:
+                out_high = level_list_reverse[i + 2]
+            out_level_list.append(out_low)
+        out_level_list.reverse()
+        return out_level_list
+
+    def frequency_interpolation(self,x,seq_len,target_len):
+        len_ratio = seq_len/target_len
+        x_fft = torch.fft.rfft(x, dim=2)
+        out_fft = torch.zeros([x_fft.size(0),x_fft.size(1),target_len//2+1],dtype=x_fft.dtype).to(x_fft.device)
+        out_fft[:,:,:seq_len//2+1] = x_fft
+        out = torch.fft.irfft(out_fft, dim=2)
+        out = out * len_ratio
+        return out
+
+class M_KAN(nn.Module):
+    def __init__(self,d_model,seq_len,order):
+        super().__init__()
+        self.channel_mixer = nn.Sequential(
+            ChebyKANLayer(d_model, d_model,order)
+        )
+        self.conv = BasicConv(d_model,d_model,kernel_size=3,degree=order,groups=d_model)
+    def forward(self,x):
+        x1 = self.channel_mixer(x)
+        x2 = self.conv(x)
+        out = x1 + x2
+        return out
+
+class BasicConv(nn.Module):
+    def __init__(self,c_in,c_out, kernel_size, degree,stride=1, padding=0, dilation=1, groups=1, act=False, bn=False, bias=False,dropout=0.):
+        super(BasicConv, self).__init__()
+        self.out_channels = c_out
+        self.conv = nn.Conv1d(c_in,c_out, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, dilation=dilation, groups=groups, bias=bias)
+        self.bn = nn.BatchNorm1d(c_out) if bn else None
+        self.act = nn.GELU() if act else None
+        self.dropout = nn.Dropout(dropout)
+    def forward(self, x):
+        if self.bn is not None:
+            x = self.bn(x)
+        x = self.conv(x.transpose(-1,-2)).transpose(-1,-2)
+        if self.act is not None:
+            x = self.act(x)
+        if self.dropout is not None:
+            x = self.dropout(x)
+        return x
+
+
+class moving_avg(nn.Module):
+    """
+    Moving average block to highlight the trend of time series
+    """
+
+    def __init__(self, kernel_size, stride):
+        super(moving_avg, self).__init__()
+        self.kernel_size = kernel_size
+        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
+
+    def forward(self, x):
+        # padding on the both ends of time series
+        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
+        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
+        x = torch.cat([front, x, end], dim=1)
+        x = self.avg(x.permute(0, 2, 1))
+        x = x.permute(0, 2, 1)
+        return x
+
+class series_decomp(nn.Module):
+    """
+    Series decomposition block
+    """
+
+    def __init__(self, kernel_size):
+        super(series_decomp, self).__init__()
+        self.moving_avg = moving_avg(kernel_size, stride=1)
+
+    def forward(self, x):
+        moving_mean = self.moving_avg(x)
+        res = x - moving_mean
+        return res, moving_mean
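ChebyKANLinear relies on the identity cos(k·arccos(t)) = T_k(t) on [-1, 1]: after squashing the input with tanh, the acos/arange/cos chain expands each scalar into a Chebyshev polynomial basis that the learned cheby_coeffs then mix. A quick standalone check of that identity against the three-term recurrence (not dsipts code):

import torch

x = torch.linspace(-3, 3, 7)
t = torch.tanh(x)                                    # squash into (-1, 1)
degree = 4
# basis[n, k] = cos(k * arccos(t_n)), the expansion used inside ChebyKANLinear
basis = torch.cos(torch.acos(t).unsqueeze(-1) * torch.arange(degree + 1))

# Chebyshev recurrence: T_0 = 1, T_1 = t, T_k = 2 t T_{k-1} - T_{k-2}
T = [torch.ones_like(t), t]
for _ in range(2, degree + 1):
    T.append(2 * t * T[-1] - T[-2])
recurrence = torch.stack(T, dim=-1)

print(torch.allclose(basis, recurrence, atol=1e-5))  # True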
dsipts/models/timekan/__init__.py
File without changes
dsipts/models/timesnet/Layers.py
ADDED
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.fft
+
+
+class Inception_Block_V1(nn.Module):
+    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
+        super(Inception_Block_V1, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_kernels = num_kernels
+        kernels = []
+        for i in range(self.num_kernels):
+            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i))
+        self.kernels = nn.ModuleList(kernels)
+        if init_weight:
+            self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        res_list = []
+        for i in range(self.num_kernels):
+            res_list.append(self.kernels[i](x))
+        res = torch.stack(res_list, dim=-1).mean(-1)
+        return res
+
+def FFT_for_Period(x, k=2):
+    # [B, T, C]
+    xf = torch.fft.rfft(x, dim=1)
+    # find period by amplitudes
+    frequency_list = abs(xf).mean(0).mean(-1)
+    frequency_list[0] = 0
+    _, top_list = torch.topk(frequency_list, k)
+    top_list = top_list.detach().cpu().numpy()
+    period = x.shape[1] // top_list
+    return period, abs(xf).mean(-1)[:, top_list]
+
+
+
+class TimesBlock(nn.Module):
+    def __init__(self, seq_len,pred_len,top_k,d_model,d_ff,num_kernels):
+        super(TimesBlock, self).__init__()
+        self.seq_len = seq_len
+        self.pred_len = pred_len
+        self.k = top_k
+        # parameter-efficient design
+        self.conv = nn.Sequential(
+            Inception_Block_V1(d_model, d_ff,
+                               num_kernels=num_kernels),
+            nn.GELU(),
+            Inception_Block_V1(d_ff, d_model,
+                               num_kernels=num_kernels)
+        )
+
+    def forward(self, x):
+        B, T, N = x.size()
+
+        period_list, period_weight = FFT_for_Period(x, self.k)
+
+        res = []
+        for i in range(self.k):
+            period = period_list[i]
+            # padding
+            if (self.seq_len + self.pred_len) % period != 0:
+                length = (
+                    ((self.seq_len + self.pred_len) // period) + 1) * period
+                padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
+                out = torch.cat([x, padding], dim=1)
+            else:
+                length = (self.seq_len + self.pred_len)
+                out = x
+            # reshape
+            out = out.reshape(B, length // period, period,
+                              N).permute(0, 3, 1, 2).contiguous()
+            # 2D conv: from 1d Variation to 2d Variation
+            out = self.conv(out)
+            # reshape back
+            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
+            res.append(out[:, :(self.seq_len + self.pred_len), :])
+        res = torch.stack(res, dim=-1)
+        # adaptive aggregation
+        period_weight = F.softmax(period_weight, dim=1)
+        period_weight = period_weight.unsqueeze(
+            1).unsqueeze(1).repeat(1, T, N, 1)
+        res = torch.sum(res * period_weight, -1)
+        # residual connection
+        res = res + x
+        return res
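FFT_for_Period picks the dominant periods by averaging the FFT amplitude spectrum over batch and channels, taking the top-k frequency bins, and converting each bin index to a period as sequence_length // index; TimesBlock then folds the series into a 2D (cycles x period) tensor per detected period. A small standalone demo of the period detection on a synthetic signal (toy data, not dsipts code):

import torch

B, T, C = 4, 96, 3
t = torch.arange(T, dtype=torch.float32)
x = torch.sin(2 * torch.pi * t / 24).reshape(1, T, 1).repeat(B, 1, C)   # true period: 24 steps

xf = torch.fft.rfft(x, dim=1)
amp = xf.abs().mean(0).mean(-1)          # average amplitude per frequency bin
amp[0] = 0                               # ignore the DC component
_, top = torch.topk(amp, k=2)
periods = T // top
print(periods[0].item())                 # 24 (bin 4 dominates, and 96 // 4 = 24)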
dsipts/models/timesnet/__init__.py
File without changes
dsipts/version.py
ADDED
@@ -0,0 +1 @@
+__version__ = "1.1.15"