dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/data_structure/modifiers.py
@@ -0,0 +1,213 @@
+
+ from abc import abstractmethod, ABC
+ from typing import Tuple
+ from sklearn.cluster import BisectingKMeans
+ from scipy.stats import bootstrap
+ from torch.utils.data import Dataset
+ import torch
+ import numpy as np
+ import logging
+ from .utils import MyDataset
+
+
+ class VVADataset(Dataset):
+
+     def __init__(self, x, y, y_orig, t, length_in, length_out, num_digits):
+         self.length_in = length_in
+         self.length_out = length_out
+         self.num_digits = num_digits
+         self.x_emb = torch.tensor(x).long()
+         self.y_emb = torch.tensor(y).long()
+         self.y = torch.tensor(y_orig)
+         self.t = t
+
+     def __len__(self):
+         """
+         :meta private:
+         """
+         return len(self.x_emb)  # ...
+
+     def get_vocab_size(self):
+         """
+         :meta private:
+         """
+         return self.num_digits
+
+     def get_block_size(self):
+         """
+         :meta private:
+         """
+         # the published code returns `self.length * 2 - 1`, but no `self.length`
+         # attribute is ever set; the concatenated sequence built in __getitem__
+         # has length_in + length_out tokens, so the block size is that minus one
+         return self.length_in + self.length_out - 1
+
+     def __getitem__(self, idx):
+         """
+         :meta private:
+         """
+         inp = self.x_emb[idx]
+         sol = self.y_emb[idx]
+         cat = torch.cat((inp, sol), dim=0)
+
+         # the inputs to the transformer will be the offset sequence
+         x = cat[:-1].clone()
+         y = cat[1:].clone()
+         # we only want to predict at output locations, so mask out the loss at
+         # the input locations (the published code masks the first length_out-1
+         # positions, which is equivalent only when length_in == length_out)
+         y[:self.length_in - 1] = -1
+         return {'x_emb': x, 'y_emb': y, 'y': self.y[idx]}
+
+
+ class Modifier(ABC):
+     def __init__(self, **kwargs):
+         """In the constructor you can store some parameters of the modifier. They will be saved when the time series is saved.
+         """
+         super(Modifier, self).__init__()
+         self.__dict__.update(kwargs)
+
+     @abstractmethod
+     def fit_transform(self, train: MyDataset, val: MyDataset) -> Tuple[Dataset, Dataset]:
+         """This function is called before the training procedure and should transform the standard Datasets into the new Datasets.
+
+         Args:
+             train (MyDataset): initial train `Dataset`
+             val (MyDataset): initial validation `Dataset`
+
+         Returns:
+             Dataset, Dataset: transformed train and validation `Datasets`
+         """
+         return train, val
+
+     @abstractmethod
+     def transform(self, test: MyDataset) -> Dataset:
+         """Similar to `fit_transform`, but only the transformation step is performed; it is used in the inference function before calling the inference method.
+
+         Args:
+             test (MyDataset): initial test `Dataset`
+
+         Returns:
+             Dataset: transformed test `Dataset`
+         """
+         return test
+
+     @abstractmethod
+     def inverse_transform(self, res: np.ndarray, real: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         """The results must be inverted with respect to the prediction task.
+
+         Args:
+             res (np.ndarray): raw predictions
+             real (np.ndarray): raw real data
+
+         Returns:
+             [np.ndarray, np.ndarray]: inverse transformation of the predictions and of the real data
+         """
+         return res, real
+
+
+ class ModifierVVA(Modifier):
+     """This modifier is used for the custom model VVA. The initial data are divided into smaller segments and then tokenized using a clustering procedure (`fit_transform`).
+     The centroids of the clusters are stored. A GPT model is then trained on the tokens and the predictions are inverted using the centroid information.
+     """
+
+     def fit_transform(self, train: MyDataset, val: MyDataset) -> Tuple[Dataset, Dataset]:
+         """BisectingKMeans is used on segments of length `token_split`.
+
+         Args:
+             train (MyDataset): initial train `Dataset`
+             val (MyDataset): initial validation `Dataset`
+
+         Returns:
+             Dataset, Dataset: transformed train and validation `Datasets`
+         """
+         idx_target = train.idx_target
+         assert len(idx_target) == 1, 'This works only with single channel prediction'
+
+         samples, length, _ = train.data['y'].shape
+         tmp = train.data['x_num_past'][:, :, idx_target[0]].reshape(samples, -1, self.token_split)
+         _, length_in, _ = tmp.shape
+         length_out = length // self.token_split
+         tmp = tmp.reshape(-1, self.token_split)
+         cl = BisectingKMeans(n_clusters=self.max_voc_size)
+         clusters = cl.fit_predict(tmp)
+         self.cl = cl
+         self.centroids = []
+         cls, counts = np.unique(clusters, return_counts=True)
+         logging.info(counts)
+
+         for i in cls:
+             res = []
+             data = tmp[np.where(clusters == i)[0]]
+             if len(data) > 1:
+                 for j in range(data.shape[1]):
+                     bootstrap_ci = bootstrap((data[:, j],), np.median, n_resamples=50, confidence_level=0.9, random_state=1, method='percentile')
+                     res.append([bootstrap_ci.confidence_interval.low, np.median(data[:, j]), bootstrap_ci.confidence_interval.high])
+                 self.centroids.append(np.array(res))
+             else:
+                 self.centroids.append(np.repeat(data.T, 3, axis=1))
+
+         self.centroids = np.array(self.centroids)  ## clusters x length x 3
+
+         x_train = clusters.reshape(-1, length_in)
+         samples = train.data['y'].shape[0]
+         y_train = cl.predict(train.data['y'].squeeze().reshape(samples, -1, self.token_split).reshape(-1, self.token_split)).reshape(-1, length_out)
+         samples = val.data['y'].shape[0]
+         y_validation = cl.predict(val.data['y'].squeeze().reshape(samples, -1, self.token_split).reshape(-1, self.token_split)).reshape(-1, length_out)
+         x_validation = cl.predict(val.data['x_num_past'][:, :, idx_target[0]].reshape(samples, -1, self.token_split).reshape(-1, self.token_split)).reshape(-1, length_in)
+         train_dataset = VVADataset(x_train, y_train, train.data['y'].squeeze(), train.t, length_in, length_out, self.max_voc_size)
+         validation_dataset = VVADataset(x_validation, y_validation, val.data['y'].squeeze(), val.t, length_in, length_out, self.max_voc_size)
+         return train_dataset, validation_dataset
+
+     def transform(self, test: MyDataset) -> Dataset:
+         """Similar to `fit_transform`, but only the transformation step is performed.
+
+         Args:
+             test (MyDataset): initial test `Dataset`
+
+         Returns:
+             Dataset: transformed test `Dataset`
+         """
+         idx_target = test.idx_target
+
+         samples, length, _ = test.data['y'].shape
+         tmp = test.data['x_num_past'][:, :, idx_target[0]].reshape(samples, -1, self.token_split)
+         _, length_in, _ = tmp.shape
+         length_out = length // self.token_split
+
+         tmp = tmp.reshape(-1, self.token_split)
+         clusters = self.cl.predict(tmp)
+         x = clusters.reshape(-1, length_in)
+         y = self.cl.predict(test.data['y'].squeeze().reshape(samples, -1, self.token_split).reshape(-1, self.token_split)).reshape(-1, length_out)
+
+         return VVADataset(x, y, test.data['y'].squeeze(), test.t, length_in, length_out, self.max_voc_size)
+
+     def inverse_transform(self, res: np.ndarray, real: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         """The results must be inverted with respect to the prediction task.
+
+         Args:
+             res (np.ndarray): raw predictions
+             real (np.ndarray): raw real data
+
+         Returns:
+             [np.ndarray, np.ndarray]: inverse transformation of the predictions and of the real data
+         """
+         tot = []
+         for sample in res:
+             tmp_sample = []
+             for index in sample:
+                 tmp = []
+                 for i in index:
+                     tmp.append(self.centroids[i])
+                 tmp = np.array(tmp)
+                 if tmp.shape[0] == 1:
+                     tmp2 = tmp[0, :, :]
+                 else:
+                     tmp2 = tmp.mean(axis=0)
+                 tmp2[:, 0] -= 1.96 * tmp.std(axis=0)[:, 0]  # widen using the confidence interval
+                 tmp2[:, 2] += 1.96 * tmp.std(axis=0)[:, 2]
+                 tmp_sample.append(tmp2)
+             tot.append(np.vstack(tmp_sample))
+
+         return np.expand_dims(np.stack(tot), 2), np.expand_dims(real, 2)
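
For orientation, the tokenization step that `ModifierVVA.fit_transform` performs can be reduced to a few lines: cut the series into non-overlapping segments of `token_split` points, cluster the segments with `BisectingKMeans`, and use the cluster ids as the GPT vocabulary. The sketch below is a hedged, self-contained illustration of that idea; the series and the `token_split` and `max_voc_size` values are invented for the example and are not dsipts defaults.

    import numpy as np
    from sklearn.cluster import BisectingKMeans

    # illustrative settings (assumptions, not taken from the package)
    token_split = 4     # points per segment, i.e. per token
    max_voc_size = 16   # number of clusters = vocabulary size
    series = np.sin(np.linspace(0, 20, 400)) + 0.1 * np.random.randn(400)

    # cut the series into non-overlapping segments of token_split points
    n = len(series) // token_split * token_split
    segments = series[:n].reshape(-1, token_split)

    # tokenize: each segment becomes the id of its nearest cluster
    cl = BisectingKMeans(n_clusters=max_voc_size, random_state=0)
    tokens = cl.fit_predict(segments)                  # shape: (n_segments,)

    # de-tokenize: map each token back to its cluster centroid
    reconstruction = cl.cluster_centers_[tokens].reshape(-1)
    print(np.abs(reconstruction - series[:n]).mean())  # mean tokenization error

Where dsipts goes further is in `inverse_transform`: instead of plain centroids it stores bootstrap confidence bands per cluster, so the de-tokenized prediction carries an uncertainty interval.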
dsipts/data_structure/utils.py
@@ -0,0 +1,173 @@
+ from enum import Enum
+ from typing import Union
+ import pandas as pd
+ from torch.utils.data import Dataset
+ import numpy as np
+
+ try:
+     from lightning.pytorch.callbacks import Callback
+ except ImportError:
+     from pytorch_lightning import Callback
+ import torch
+ import os
+ import logging
+
+
+ def beauty_string(message: str, type: str, verbose: bool):
+     size = 150
+     if verbose is True:
+         if type == 'block':
+             characters = len(message)
+             # note: the border width is computed against 100 even though the
+             # banner itself is `size` (150) characters wide
+             border = max((100 - characters) // 2 - 5, 0)
+             logging.info('\n')
+             logging.info(f"{'#'*size}")
+             logging.info(f"{'#'*border}{' '*(size-border*2)}{'#'*border}")
+             logging.info(f"{message:^{size}}")
+             logging.info(f"{'#'*border}{' '*(size-border*2)}{'#'*border}")
+             logging.info(f"{'#'*size}")
+         elif type == 'section':
+             logging.info('\n')
+             logging.info(f"{'#'*size}")
+             logging.info(f"{message:^{size}}")
+             logging.info(f"{'#'*size}")
+         elif type == 'info':
+             logging.info(f"{message:^{size}}")
+         else:
+             logging.info(message)
+
+
+ def extend_time_df(x: pd.DataFrame, freq: Union[str, int], group: Union[str, None] = None, global_minmax: bool = False) -> pd.DataFrame:
+     """Utility for generating a dense time axis onto which the real data can then be merged.
+
+     Args:
+         x (pd.DataFrame): dataframe containing the column `time`
+         freq (str or int): frequency (in pandas notation if a string) of the resulting dataframe
+         group (str or None): if not None, the min and max are computed per value of the `group` column. Defaults to None
+         global_minmax (bool): if True, the min and max are computed globally instead of per group. Usually used for stacked models
+
+     Returns:
+         pd.DataFrame: a dataframe whose `time` column ranges from the minimum of `x` to the maximum, with frequency `freq`
+     """
+     if group is None:
+         if isinstance(freq, int):
+             empty = pd.DataFrame({'time': list(range(x.time.min(), x.time.max(), freq))})
+         else:
+             empty = pd.DataFrame({'time': pd.date_range(x.time.min(), x.time.max(), freq=freq)})
+     else:
+         if global_minmax:
+             _min = pd.DataFrame({group: x[group].unique(), 'time': x.time.min()})
+             _max = pd.DataFrame({group: x[group].unique(), 'time': x.time.max()})
+         else:
+             _min = x.groupby(group).time.min().reset_index()
+             _max = x.groupby(group).time.max().reset_index()
+         empty = []
+         for c in x[group].unique():
+             if isinstance(freq, int):
+                 empty.append(pd.DataFrame({group: c, 'time': np.arange(_min.time[_min[group] == c].values[0], _max.time[_max[group] == c].values[0], freq)}))
+             else:
+                 empty.append(pd.DataFrame({group: c, 'time': pd.date_range(_min.time[_min[group] == c].values[0], _max.time[_max[group] == c].values[0], freq=freq)}))
+         empty = pd.concat(empty, ignore_index=True)
+     return empty
+
+
+ class MetricsCallback(Callback):
+     """PyTorch Lightning metric callback.
+
+     :meta private:
+     """
+
+     def __init__(self, dirpath):
+         super().__init__()
+         self.dirpath = dirpath
+         self.metrics = {'val_loss': [], 'train_loss': []}
+
+     def on_validation_end(self, trainer, pl_module):
+         for c in trainer.callback_metrics:
+             self.metrics[c].append(trainer.callback_metrics[c].item())
+         ## write the csv in a convenient way
+         tmp = self.metrics.copy()
+         if len(tmp['train_loss']) > 0:
+             tmp['val_loss'] = tmp['val_loss'][-len(tmp['train_loss']):]
+         else:
+             tmp['val_loss'] = tmp['val_loss'][2:]
+
+         losses = pd.DataFrame(tmp)
+         losses.to_csv(os.path.join(self.dirpath, 'loss.csv'), index=False)
+
+     def on_train_end(self, trainer, pl_module):
+         losses = self.metrics
+         ## not clear why the first two validation calls happen before training starts
+         if len(losses['train_loss']) > 0:
+             losses['val_loss'] = losses['val_loss'][-len(losses['train_loss']):]
+         else:
+             losses['val_loss'] = losses['val_loss'][2:]
+
+         losses = pd.DataFrame(losses)
+         ## workaround for the case with more than one GPU!
+         losses.to_csv(os.path.join(self.dirpath, f'{np.random.randint(10000)}__losses__.csv'), index=False)
+         print("Saving losses on file because multigpu not working")
+
+
+ class MyDataset(Dataset):
+
+     def __init__(self, data: dict, t: np.ndarray, groups: np.ndarray, idx_target: Union[np.ndarray, None], idx_target_future: Union[np.ndarray, None]) -> None:
+         """Extension of the `Dataset` class. During training the returned item is a batch containing the standard keys.
+
+         Args:
+             data (dict): a dictionary in which each key maps to a np.array containing the data. The keys are:
+                 y: the target variable(s)
+                 x_num_past: the numerical past variables
+                 x_num_future: the numerical future variables
+                 x_cat_past: the categorical past variables
+                 x_cat_future: the categorical future variables
+             t (np.ndarray): the time array related to the target variables
+             groups (np.ndarray): the group labels of the samples
+             idx_target (Union[np.ndarray, None]): indexes in the past data that represent the target features (for differential analysis or detrending strategies)
+             idx_target_future (Union[np.ndarray, None]): indexes in the future data that represent the target features (for differential analysis or detrending strategies)
+         """
+         self.data = data
+         self.t = t
+         self.groups = groups
+         self.idx_target = np.array(idx_target) if idx_target is not None else None
+         self.idx_target_future = np.array(idx_target_future) if idx_target_future is not None else None
+
+     def __len__(self):
+         return len(self.data['x_num_past'])
+
+     def __getitem__(self, idxs):
+         sample = {}
+         for k in self.data:
+             sample[k] = self.data[k][idxs]
+         if self.idx_target is not None:
+             sample['idx_target'] = self.idx_target
+         if self.idx_target_future is not None:
+             sample['idx_target_future'] = self.idx_target_future
+         return sample
+
+
+ class ActionEnum(Enum):
+     """Action of a categorical variable.
+
+     :meta private:
+     """
+     multiplicative: str = 'multiplicative'
+     additive: str = 'additive'
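
As a usage illustration of `extend_time_df` (it builds a dense time axis onto which the sparse observations are merged back, so that gaps become explicit NaN rows), here is a minimal sketch with invented sample data; the merge step is shown inline rather than through the package:

    import pandas as pd

    # sparse observations with gaps in the time column (invented sample data)
    df = pd.DataFrame({
        'time': pd.to_datetime(['2024-01-01', '2024-01-03', '2024-01-06']),
        'value': [1.0, 2.0, 3.0],
    })

    # what extend_time_df(df, freq='D') produces in the group=None case:
    full = pd.DataFrame({'time': pd.date_range(df.time.min(), df.time.max(), freq='D')})

    # merging the real data back makes the missing days explicit NaN rows
    merged = full.merge(df, on='time', how='left')
    print(merged)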
dsipts/models/Autoformer.py
@@ -0,0 +1,199 @@
+ ## Copyright 2022 DLinear Authors (https://github.com/cure-lab/LTSF-Linear/tree/main?tab=Apache-2.0-1-ov-file#readme)
+ ## Code modified to align the notation and the batch generation
+ ## extended to everything present in the informer and autoformer folders
+
+ from torch import nn
+ import torch
+
+ try:
+     import lightning.pytorch as pl
+     from .base_v2 import Base
+     OLD_PL = False
+ except ImportError:
+     import pytorch_lightning as pl
+     OLD_PL = True
+     from .base import Base
+ from typing import List, Union
+ from ..data_structure.utils import beauty_string
+ from .utils import get_activation, get_scope, QuantileLossMO
+ from .autoformer.layers import AutoCorrelation, AutoCorrelationLayer, Encoder, Decoder, \
+     EncoderLayer, DecoderLayer, my_Layernorm, series_decomp, PositionalEmbedding
+ from .utils import Embedding_cat_variables
+
+
+ class Autoformer(Base):
+     handle_multivariate = True
+     handle_future_covariates = True
+     handle_categorical_variables = True
+     handle_quantile_loss = True
+     description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
+
+     def __init__(self,
+                  label_len: int,
+                  d_model: int,
+                  dropout_rate: float,
+                  kernel_size: int,
+                  activation: str = 'torch.nn.ReLU',
+                  factor: float = 0.5,
+                  n_head: int = 1,
+                  n_layer_encoder: int = 2,
+                  n_layer_decoder: int = 2,
+                  hidden_size: int = 1048,
+                  **kwargs) -> None:
+         """Autoformer from https://github.com/cure-lab/LTSF-Linear
+
+         Args:
+             label_len (int): see the original implementation; it acts as a warm-up window (the decoder also produces some past predictions, which are filtered out at the end)
+             d_model (int): embedding dimension of the attention layers
+             dropout_rate (float): dropout rate
+             kernel_size (int): kernel size of the series decomposition
+             activation (str, optional): activation function. Defaults to 'torch.nn.ReLU'.
+             factor (float, optional): parameter of `.autoformer.layers.AutoCorrelation` used to find the top k. Defaults to 0.5.
+             n_head (int, optional): number of attention heads. Defaults to 1.
+             n_layer_encoder (int, optional): number of encoder layers. Defaults to 2.
+             n_layer_decoder (int, optional): number of decoder layers. Defaults to 2.
+             hidden_size (int, optional): output dimension of the transformer layers. Defaults to 1048.
+         """
+         super().__init__(**kwargs)
+         beauty_string(self.description, 'info', True)
+
+         if activation == 'torch.nn.SELU':
+             beauty_string('SELU does not require BN', 'info', self.verbose)
+         if isinstance(activation, str):
+             activation = get_activation(activation)
+         else:
+             beauty_string('There is a bug in pytorch lightning: the constructor is called twice', 'info', self.verbose)
+
+         self.save_hyperparameters(logger=False)
+
+         self.seq_len = self.past_steps
+         self.label_len = label_len
+         self.pred_len = self.future_steps
+
+         self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+         self.emb_fut = Embedding_cat_variables(self.future_steps + label_len, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+         emb_past_out_channel = self.emb_past.output_channels
+         emb_fut_out_channel = self.emb_fut.output_channels
+
+         # series decomposition
+         self.decomp = series_decomp(kernel_size)
+
+         self.linear_encoder = nn.Sequential(nn.Linear(self.past_channels + emb_past_out_channel, self.past_channels * 2),
+                                             activation(),
+                                             nn.Dropout(dropout_rate),
+                                             nn.Linear(self.past_channels * 2, d_model * 2),
+                                             activation(),
+                                             nn.Dropout(dropout_rate),
+                                             nn.Linear(d_model * 2, d_model))
+
+         self.linear_decoder = nn.Sequential(nn.Linear(self.future_channels + emb_fut_out_channel, self.future_channels * 2),
+                                             activation(),
+                                             nn.Dropout(dropout_rate),
+                                             nn.Linear(self.future_channels * 2, d_model * 2),
+                                             activation(),
+                                             nn.Dropout(dropout_rate),
+                                             nn.Linear(d_model * 2, d_model))
+
+         # Encoder
+         self.encoder = Encoder(
+             [
+                 EncoderLayer(
+                     AutoCorrelationLayer(
+                         AutoCorrelation(False, factor, attention_dropout=dropout_rate,
+                                         output_attention=False),
+                         d_model, n_head),
+                     d_model,
+                     hidden_size,
+                     moving_avg=kernel_size,
+                     dropout=dropout_rate,
+                     activation=activation
+                 ) for _ in range(n_layer_encoder)
+             ],
+             norm_layer=my_Layernorm(d_model)
+         )
+         # Decoder
+         self.decoder = Decoder(
+             [
+                 DecoderLayer(
+                     AutoCorrelationLayer(
+                         AutoCorrelation(True, factor, attention_dropout=dropout_rate,
+                                         output_attention=False),
+                         d_model, n_head),
+                     AutoCorrelationLayer(
+                         AutoCorrelation(False, factor, attention_dropout=dropout_rate,
+                                         output_attention=False),
+                         d_model, n_head),
+                     d_model,
+                     self.out_channels,
+                     hidden_size,
+                     moving_avg=kernel_size,
+                     dropout=dropout_rate,
+                     activation=activation,
+                 )
+                 for _ in range(n_layer_decoder)
+             ],
+             norm_layer=my_Layernorm(d_model),
+             projection=nn.Linear(d_model, self.out_channels * self.mul, bias=True)
+         )
+         self.projection = nn.Linear(self.past_channels, self.out_channels * self.mul)
+
+     def forward(self, batch):
+         idx_target_future = batch['idx_target_future'][0]
+         x = batch['x_num_past'].to(self.device)
+         BS = x.shape[0]
+         if 'x_cat_future' in batch.keys():
+             emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
+         else:
+             emb_fut = self.emb_fut(BS, None)
+         if 'x_cat_past' in batch.keys():
+             emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
+         else:
+             emb_past = self.emb_past(BS, None)
+
+         # note: x_future is only assigned when numerical future covariates are
+         # present in the batch; the decoder input below requires it
+         if 'x_num_future' in batch.keys():
+             x_future = batch['x_num_future'].to(self.device)
+             # zero out the (unknown) future values of the target channels
+             x_future[:, -self.pred_len:, idx_target_future] = 0
+
+         mean = torch.mean(x, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
+         zeros = torch.zeros([x_future.shape[0], self.pred_len, x.shape[2]], device=x.device)
+         seasonal_init, trend_init = self.decomp(x)
+         # decoder input: warm up with the last label_len steps of the decomposition
+         # (seasonal_init is retained from the reference implementation, but the
+         # decoder input here is built from the future covariates via linear_decoder)
+         trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
+         seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)
+         # encoder
+         enc_out = self.linear_encoder(torch.cat([x, emb_past], 2))
+         enc_out, attns = self.encoder(enc_out, attn_mask=None)
+         # decoder
+         dec_out = self.linear_decoder(torch.cat([x_future, emb_fut], 2))
+         seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, trend=trend_init)
+         # final projection
+         trend_part = self.projection(trend_part)
+         dec_out = trend_part + seasonal_part
+
+         BS = dec_out.shape[0]
+         return dec_out[:, -self.pred_len:, :].reshape(BS, self.pred_len, -1, self.mul)  # [B, L, D, MUL]
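
The decoder warm-up in `forward` (trend initialized with the series mean, seasonal part with zeros, both prefixed by the last `label_len` steps) rests on the moving-average decomposition that `series_decomp` implements in the Autoformer line of work. Below is a minimal stand-alone sketch of that decomposition, written here independently of the package's own layer; the shapes and kernel size are illustrative:

    import torch

    def moving_avg_decomp(x: torch.Tensor, kernel_size: int):
        # x: [batch, length, channels]; the edges are pad-replicated so the
        # trend keeps the same length as the input (as in Autoformer)
        pad = (kernel_size - 1) // 2
        front = x[:, :1, :].repeat(1, pad, 1)
        back = x[:, -1:, :].repeat(1, kernel_size - 1 - pad, 1)
        padded = torch.cat([front, x, back], dim=1).permute(0, 2, 1)
        trend = torch.nn.functional.avg_pool1d(padded, kernel_size, stride=1).permute(0, 2, 1)
        seasonal = x - trend        # residual after removing the trend
        return seasonal, trend

    x = torch.randn(2, 96, 3)                       # [B, L, D]
    seasonal, trend = moving_avg_decomp(x, kernel_size=25)
    print(seasonal.shape, trend.shape)              # both [2, 96, 3]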