dsipts 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic. Click here for more details.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/base_v2.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
1
|
+
|
|
2
|
+
from torch import optim
|
|
3
|
+
import torch
|
|
4
|
+
import lightning.pytorch as pl
|
|
5
|
+
from torch.optim.lr_scheduler import StepLR
|
|
6
|
+
from abc import abstractmethod
|
|
7
|
+
from .utils import SinkhornDistance, SoftDTWBatch,PathDTWBatch,pairwise_distances
|
|
8
|
+
from ..data_structure.utils import beauty_string
|
|
9
|
+
from .samformer.utils import SAM
|
|
10
|
+
from .utils import get_scope
|
|
11
|
+
import numpy as np
|
|
12
|
+
from aim import Image
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
from typing import List, Union
|
|
15
|
+
from .utils import QuantileLossMO
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def standardize_momentum(x, order):
    """Return the central moment of ``x`` of the given ``order`` along dim 1.

    Despite the name, the result is NOT standardized: no division by
    ``std ** order`` is applied, so this is the raw central moment
    ``mean((x - mean(x, dim=1)) ** order, dim=1)``.

    Args:
        x (torch.Tensor): tensor with at least 2 dimensions; moments are taken
            along dimension 1 (used on targets and predictions by the
            ``'high_order'`` loss).
        order (int): exponent of the moment (2 -> variance, 3 -> third moment, ...).

    Returns:
        torch.Tensor: same shape as ``x`` with dimension 1 removed.
    """
    # keepdim broadcasting gives the same values as repeating the mean along
    # dim 1, without materialising a full copy, and works for any rank >= 2.
    mean = x.mean(dim=1, keepdim=True)
    return torch.pow(x - mean, order).mean(dim=1)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def dilate_loss(outputs, targets, alpha, gamma, device):
    """Combine a soft-DTW shape term and a DTW-path temporal term.

    Args:
        outputs: predicted series; the original note says shape
            (batch_size, N_output, 1) — each element is flattened to a column.
        targets: target series with the same layout as ``outputs``.
        alpha (float): weight of the shape term; ``1 - alpha`` weights the
            temporal term.
        gamma (float): smoothing parameter passed to the soft-DTW / path ops.
        device: device on which the distance matrices are built.

    Returns:
        torch.Tensor: ``alpha * loss_shape + (1 - alpha) * loss_temporal``.
    """
    batch_size, N_output = outputs.shape[0:2]
    softdtw_batch = SoftDTWBatch.apply

    # Per-sample pairwise distance matrix between target and predicted steps.
    D = torch.zeros((batch_size, N_output, N_output), device=device)
    for k in range(batch_size):
        Dk = pairwise_distances(targets[k, :, :].view(-1, 1), outputs[k, :, :].view(-1, 1))
        D[k:k + 1, :, :] = Dk
    loss_shape = softdtw_batch(D, gamma)

    path_dtw = PathDTWBatch.apply
    path = path_dtw(D, gamma)
    # torch.range is deprecated; arange(1, N + 1) in the default dtype yields
    # exactly the same values 1..N.
    steps = torch.arange(1, N_output + 1, dtype=torch.get_default_dtype()).view(N_output, 1)
    Omega = pairwise_distances(steps).to(device)
    loss_temporal = torch.sum(path * Omega) / (N_output * N_output)

    loss = alpha * loss_shape + (1 - alpha) * loss_temporal
    return loss
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Base(pl.LightningModule):
    """Common Lightning base class for the V2 models.

    Subclasses override ``__init__`` and ``forward``; ``inference`` may be
    overridden when prediction differs from the training forward pass.
    """

    ############### SET THE PROPERTIES OF THE ARCHITECTURE##############

    handle_multivariate = False
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
    #####################################################################

    @abstractmethod
    def __init__(self, verbose: bool,
                 past_steps: int,
                 future_steps: int,
                 past_channels: int,
                 future_channels: int,
                 out_channels: int,
                 embs_past: List[int],
                 embs_fut: List[int],
                 n_classes: int = 0,
                 persistence_weight: float = 0.0,
                 loss_type: str = 'l1',
                 quantiles: List[float] = [],  # never mutated here, so the shared default is safe
                 reduction_mode: str = 'mean',
                 use_classical_positional_encoder: bool = False,
                 emb_dim: int = 16,
                 optim: Union[str, None] = None,
                 optim_config: dict = None,
                 scheduler_config: dict = None,):
        """
        This is the basic model, each model implemented must overwrite the init method and the forward method.
        The inference step is optional, by default it uses the forward method but for recurrent
        network you should implement your own method

        Args:
            verbose (bool): Flag to enable verbose logging.
            past_steps (int): Number of past time steps to consider.
            future_steps (int): Number of future time steps to predict.
            past_channels (int): Number of channels in the past input data.
            future_channels (int): Number of channels in the future input data.
            out_channels (int): Number of output channels.
            embs_past (List[int]): List of embedding dimensions for past data.
            embs_fut (List[int]): List of embedding dimensions for future data.
            n_classes (int, optional): Number of classes for classification. Defaults to 0.
            persistence_weight (float, optional): Weight of the extra term in the custom losses. Defaults to 0.0.
            loss_type (str, optional): 'mse' selects MSELoss as base loss, anything else L1Loss; custom
                names ('linear_penalization', 'mda', 'huber', ...) are handled in ``compute_loss``. Defaults to 'l1'.
            quantiles (List[float], optional): Quantiles for quantile loss (exactly 3 if given). Defaults to an empty list.
            reduction_mode (str, optional): Mode for reduction for categorical embedding layer ('mean', 'sum', 'none'). Defaults to 'mean'.
            use_classical_positional_encoder (bool, optional): Use classical positional encoding instead of an embedding layer for the positions. Defaults to False.
            emb_dim (int, optional): Dimension of categorical embeddings. Defaults to 16.
            optim (Union[str, None], optional): Optimizer: None (Adam), 'SAM', or an expression naming an optimizer class. Defaults to None.
            optim_config (dict, optional): Keyword arguments for the optimizer. Defaults to None ({'lr': 5e-05}).
            scheduler_config (dict, optional): Keyword arguments for ``StepLR``. Defaults to None (no scheduler).

        Raises:
            AssertionError: If the number of quantiles is not equal to 3 when quantiles are provided.
            AssertionError: If the number of output channels is not 1 for classification tasks.
        """
        beauty_string('V2', 'block', True)
        super(Base, self).__init__()
        self.save_hyperparameters(logger=False)
        self.count_epoch = 0
        self.initialize = False
        self.train_loss_epoch = -100.0
        self.verbose = verbose
        self.name = self.__class__.__name__
        self.train_epoch_metrics = []
        self.validation_epoch_metrics = []

        self.quantiles = quantiles
        self.optim = optim
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config
        self.loss_type = loss_type
        self.persistence_weight = persistence_weight
        self.use_classical_positional_encoder = use_classical_positional_encoder
        self.reduction_mode = reduction_mode
        self.past_steps = past_steps
        self.future_steps = future_steps
        self.embs_past = embs_past
        self.embs_fut = embs_fut
        self.past_channels = past_channels
        self.future_channels = future_channels
        self.emb_dim = emb_dim
        self.out_channels = out_channels
        self.n_classes = n_classes
        if n_classes == 0:
            self.is_classification = False
            if len(self.quantiles) > 0:
                assert len(self.quantiles) == 3, beauty_string('ONLY 3 quantiles premitted', 'info', True)
                self.use_quantiles = True
                self.mul = len(self.quantiles)
                self.loss = QuantileLossMO(quantiles)
            else:
                self.use_quantiles = False
                self.mul = 1
                if self.loss_type == 'mse':
                    self.loss = nn.MSELoss()
                else:
                    self.loss = nn.L1Loss()
        else:
            self.is_classification = True
            self.use_quantiles = False
            self.mul = n_classes
            self.loss = torch.nn.CrossEntropyLoss()
            assert self.out_channels == 1, "Classification require only one channel"

        beauty_string(self.description, 'info', True)

    @abstractmethod
    def forward(self, batch: dict) -> torch.Tensor:
        """Forward method used during the training loop.

        Args:
            batch (dict): the batch structure. The keys are:
                y : the target variable(s). This is always present
                x_num_past: the numerical past variables. This is always present
                x_num_future: the numerical future variables
                x_cat_past: the categorical past variables
                x_cat_future: the categorical future variables
                idx_target: index of target features in the past array

        Returns:
            torch.Tensor: output of the model
        """
        return None

    def inference(self, batch: dict) -> torch.Tensor:
        """Usually it is ok to return the output of the forward method but sometimes not (e.g. RNN)

        Args:
            batch (dict): batch

        Returns:
            torch.Tensor: result
        """
        return self(batch)

    def configure_optimizers(self):
        """
        Each model has optim_config and scheduler_config.

        Builds Adam when ``optim`` is None, SAM (wrapping Adam) when ``optim == 'SAM'``,
        otherwise the optimizer class obtained by evaluating the ``optim`` string.
        A ``StepLR`` scheduler is added when ``scheduler_config`` is set.

        :meta private:
        """
        # Derive the SAM flag from the configuration rather than resetting it to False:
        # on a repeated call (initialize already True) the branch that set it is skipped,
        # and the string 'SAM' would otherwise be called as an optimizer class.
        self.has_sam_optim = isinstance(self.optim, str) and self.optim == 'SAM'
        if self.optim_config is None:
            self.optim_config = {'lr': 5e-05}

        if self.optim is None:
            optimizer = optim.Adam(self.parameters(), **self.optim_config)
            self.initialize = True
        else:
            if self.initialize is False:
                if self.has_sam_optim:
                    # SAM is driven through a closure in training_step, so Lightning's
                    # automatic optimization is turned off.
                    self.automatic_optimization = False
                    self.my_step = 0
                else:
                    # SECURITY NOTE: eval executes arbitrary code. ``optim`` must come from
                    # trusted configuration only (presumably strings like 'optim.SGD').
                    self.optim = eval(self.optim)
                    self.automatic_optimization = True

            beauty_string(self.optim, '', self.verbose)
            if self.has_sam_optim:
                optimizer = SAM(self.parameters(), base_optimizer=torch.optim.Adam, **self.optim_config)
            else:
                optimizer = self.optim(self.parameters(), **self.optim_config)
            beauty_string(optimizer, '', self.verbose)
            self.initialize = True
        self.lr = self.optim_config['lr']  ##CHECK THISs
        if self.scheduler_config is not None:
            scheduler = StepLR(optimizer, **self.scheduler_config)
            return [optimizer], [scheduler]
        else:
            return optimizer

    def training_step(self, batch, batch_idx):
        """
        PyTorch Lightning training step.

        :meta private:
        """
        if self.has_sam_optim:
            opt = self.optimizers()

            def closure():
                opt.zero_grad()
                y_hat = self(batch)
                loss = self.compute_loss(batch, y_hat)
                self.manual_backward(loss)
                return loss

            opt.step(closure)
            # Recompute the loss after the update for logging.
            y_hat = self(batch)
            loss = self.compute_loss(batch, y_hat)
        else:
            y_hat = self(batch)
            loss = self.compute_loss(batch, y_hat)

        self.train_epoch_metrics.append(loss.item())
        return loss

    def validation_step(self, batch, batch_idx):
        """
        PyTorch Lightning validation step; also plots the first element of the
        first batch to the experiment logger periodically.

        :meta private:
        """
        y_hat = self(batch)
        if batch_idx == 0:
            # Plot the median quantile when quantile loss is used, else the single output.
            if self.use_quantiles:
                idx = 1
            else:
                idx = 0
            #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
            # NOTE(review): when max_epochs < 200 the period is 1 and ``x % 1 == 1`` is
            # never true, so no plots are produced — confirm whether this is intended.
            if self.count_epoch % int(max(self.trainer.max_epochs / 100, 1)) == 1:
                for i in range(batch['y'].shape[2]):
                    real = batch['y'][0, :, i].cpu().detach().numpy()
                    pred = y_hat[0, :, i, idx].cpu().detach().numpy()
                    fig, ax = plt.subplots(figsize=(7, 5))
                    ax.plot(real, 'o-', label='real')
                    ax.plot(pred, 'o-', label='pred')
                    ax.legend()
                    ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
                    self.logger.experiment.track(Image(fig), name='cm_training_end')
                    # Release the figure; pyplot keeps every open figure alive otherwise.
                    plt.close(fig)
        self.validation_epoch_metrics.append(self.compute_loss(batch, y_hat))
        return

    def on_validation_epoch_end(self):
        """
        PyTorch Lightning hook: log the mean validation loss of the epoch.

        :meta private:
        """
        avg = torch.stack(self.validation_epoch_metrics).mean()
        self.validation_epoch_metrics = []
        self.log("val_loss", avg, sync_dist=True)
        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}', 'info', self.verbose)

    def on_train_epoch_end(self):
        """
        PyTorch Lightning hook: log the mean training loss of the epoch and, under
        manual optimization, step the learning-rate scheduler.

        :meta private:
        """
        avg = np.stack(self.train_epoch_metrics).mean()
        self.log("train_loss", avg, sync_dist=True)
        self.count_epoch += 1
        self.train_epoch_metrics = []
        self.train_loss_epoch = avg
        # With automatic_optimization disabled (SAM), Lightning does not step the
        # schedulers returned by configure_optimizers; do it once per epoch here.
        if not self.automatic_optimization:
            scheduler = self.lr_schedulers()
            if scheduler is not None:
                scheduler.step()

    def compute_loss(self, batch, y_hat):
        """
        Custom loss calculation.

        Args:
            batch (dict): batch containing at least ``y``, ``x_num_past`` and ``idx_target``.
            y_hat (torch.Tensor): model output indexed as ``y_hat[:, :, :, k]``; the last
                axis holds ``self.mul`` values (quantiles or classes).

        Returns:
            torch.Tensor: loss selected by ``self.loss_type``; unknown types fall back to
            the base loss built in ``__init__``.

        :meta private:
        """
        y = batch['y']
        if self.use_quantiles is False:
            initial_loss = self.loss(y_hat[:, :, :, 0], y)
        else:
            initial_loss = self.loss(y_hat, y)

        # Persistence baseline: last observed past value of the target columns,
        # repeated over the forecast horizon.
        x_past = batch['x_num_past'].to(self.device)
        idx_target = batch['idx_target'][0]
        x_start = x_past[:, -1, idx_target].unsqueeze(1)
        y_persistence = x_start.repeat(1, self.future_steps, 1)

        # Point forecast used by the custom losses: the single output, or the middle
        # quantile. Generally you want to work without quantile loss.
        if self.use_quantiles is False:
            x = y_hat[:, :, :, 0]
        else:
            x = y_hat[:, :, :, 1]

        # A single if/elif chain, so the catch-all ``else`` can never override the
        # result of an earlier branch (e.g. 'linear_penalization').
        if self.loss_type == 'linear_penalization':
            persistence_error = (2.0 - 10.0 * torch.clamp(
                torch.abs((y_persistence - x) / (0.001 + torch.abs(y_persistence))),
                min=0.0,
                max=max(0.05, 0.1 * (1 + np.log10(self.persistence_weight)))))
            loss = torch.mean(torch.abs(x - y) * persistence_error)

        elif self.loss_type == 'mda':
            # Mean directional accuracy penalty: sign agreement of consecutive differences.
            mda = (1 - torch.mean(torch.sign(torch.diff(x, axis=1)) * torch.sign(torch.diff(y, axis=1))))
            loss = torch.mean(torch.abs(x - y).mean(axis=1).flatten()) + self.persistence_weight * mda / 10

        elif self.loss_type == 'exponential_penalization':
            weights = (1 + self.persistence_weight * torch.exp(-torch.abs(y_persistence - x)))
            loss = torch.mean(torch.abs(x - y) * weights)

        elif self.loss_type == 'sinkhorn':
            sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='mean')
            loss = sinkhorn.compute(x, y)

        elif self.loss_type == 'additive_iv':
            std = torch.sqrt(torch.var(y, dim=(1)) + 1e-8)  ##--> BSxChannel
            x_std = torch.sqrt(torch.var(x, dim=(1)) + 1e-8)
            loss = torch.mean(torch.abs(x - y).mean(axis=1).flatten() + self.persistence_weight * torch.abs(x_std - std).flatten())

        elif self.loss_type == 'multiplicative_iv':
            std = torch.sqrt(torch.var(y, dim=(1)) + 1e-8)  ##--> BSxChannel
            x_std = torch.sqrt(torch.var(x, dim=(1)) + 1e-8)
            if self.persistence_weight > 0:
                loss = torch.mean(torch.abs(x - y).mean(axis=1) * torch.abs(x_std - std))
            else:
                loss = torch.mean(torch.abs(x - y).mean(axis=1))

        elif self.loss_type == 'global_iv':
            std_real = torch.sqrt(torch.var(y, dim=(0, 1)))
            std_predict = torch.sqrt(torch.var(x, dim=(0, 1)))
            loss = initial_loss + self.persistence_weight * torch.abs(std_real - std_predict).mean()

        elif self.loss_type == 'smape':
            loss = torch.mean(2 * torch.abs(x - y) / (torch.abs(x) + torch.abs(y)))

        elif self.loss_type == 'triplet':
            loss_fn = torch.nn.TripletMarginLoss(margin=0.01, p=1.0, swap=False)
            loss = initial_loss + self.persistence_weight * loss_fn(x, y, y_persistence)

        elif self.loss_type == 'high_order':
            loss = initial_loss
            for i in range(2, 5):
                mom_real = standardize_momentum(y, i)
                mom_pred = standardize_momentum(x, i)
                mom_loss = torch.abs(mom_real - mom_pred).mean()
                loss += self.persistence_weight * mom_loss

        elif self.loss_type == 'dilated':
            # persistence_weight is used directly as the shape/temporal trade-off
            # (the previous per-value alpha table was always overwritten by it).
            alpha = self.persistence_weight
            gamma = 0.01
            loss = 0
            ##no multichannel here: one DILATE term per channel
            for i in range(y_hat.shape[2]):
                loss += dilate_loss(y[:, :, i:i + 1], x[:, :, i:i + 1], alpha, gamma, y_hat.device)

        elif self.loss_type == 'huber':
            huber = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight / 10)
            # Compare the point forecast x (not y_hat, which keeps the extra mul axis
            # and mismatches y when quantiles are enabled).
            bs = x.shape[0]
            loss = huber(x.reshape(bs, -1), y.reshape(bs, -1))

        else:
            loss = initial_loss

        return loss
|
|
File without changes
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from einops import rearrange, repeat
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from math import sqrt
|
|
7
|
+
|
|
8
|
+
class FullAttention(nn.Module):
    '''
    Scaled dot-product attention over multi-head tensors (no masking).

    Inputs are laid out as [batch, length, heads, head_dim]. The scores are
    scaled by ``scale`` when given (and non-zero), otherwise by 1/sqrt(head_dim),
    and dropout is applied to the attention weights.
    '''
    def __init__(self, scale=None, attention_dropout=0.1):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values):
        # Only the query head dimension is needed to derive the default scale.
        _, _, _, head_dim = queries.shape
        factor = self.scale or 1. / sqrt(head_dim)

        logits = torch.einsum("blhe,bshe->bhls", queries, keys)
        weights = self.dropout(torch.softmax(logits * factor, dim=-1))
        attended = torch.einsum("bhls,bshd->blhd", weights, values)
        return attended.contiguous()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AttentionLayer(nn.Module):
    '''
    Multi-head attention layer (MSA) built around FullAttention.

    Queries/keys/values of width ``d_model`` are projected to ``n_heads`` heads
    (``d_keys``/``d_values`` per head, default d_model // n_heads), attended, and
    projected back to ``d_model``.

    When ``mix`` is True the [B, L, H, D] attention output is transposed to
    [B, H, L, D] before being flattened to [B, L, H*D], so each output row mixes
    values from several heads and positions.
    '''
    def __init__(self, d_model, n_heads, d_keys=None, d_values=None, mix=True, dropout = 0.1):
        super(AttentionLayer, self).__init__()

        key_dim = d_keys if d_keys else d_model // n_heads
        value_dim = d_values if d_values else d_model // n_heads

        # Keep this creation order: it fixes the random initialisation sequence.
        self.inner_attention = FullAttention(scale=None, attention_dropout=dropout)
        self.query_projection = nn.Linear(d_model, key_dim * n_heads)
        self.key_projection = nn.Linear(d_model, key_dim * n_heads)
        self.value_projection = nn.Linear(d_model, value_dim * n_heads)
        self.out_projection = nn.Linear(value_dim * n_heads, d_model)
        self.n_heads = n_heads
        self.mix = mix

    def forward(self, queries, keys, values):
        n_batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        q = self.query_projection(queries).view(n_batch, q_len, heads, -1)
        k = self.key_projection(keys).view(n_batch, kv_len, heads, -1)
        v = self.value_projection(values).view(n_batch, kv_len, heads, -1)

        attended = self.inner_attention(q, k, v)
        if self.mix:
            attended = attended.transpose(2, 1).contiguous()
        flat = attended.view(n_batch, q_len, -1)
        return self.out_projection(flat)
|
|
66
|
+
|
|
67
|
+
class TwoStageAttentionLayer(nn.Module):
    '''
    Two Stage Attention (TSA) layer.

    Stage 1 (cross-time) applies self-attention along the segment axis of each
    series independently. Stage 2 (cross-dimension) lets series exchange
    information through ``factor`` learnable router vectors per segment: the
    routers gather from all series, then every series reads back from them.
    input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]
    '''
    def __init__(self, seg_num, factor, d_model, n_heads, d_ff = None, dropout=0.1):
        super(TwoStageAttentionLayer, self).__init__()
        hidden = d_ff if d_ff else 4 * d_model

        # Creation order below fixes the random initialisation sequence; keep it.
        self.time_attention = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.dim_sender = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.dim_receiver = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))

        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)

        def feed_forward():
            return nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))

        self.MLP1 = feed_forward()
        self.MLP2 = feed_forward()

    def forward(self, x):
        n_batch = x.shape[0]

        # Stage 1 - cross-time: one attention sequence per (batch, series) pair.
        seq = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
        attended = self.time_attention(seq, seq, seq)
        h = self.norm1(seq + self.dropout(attended))
        h = self.norm2(h + self.dropout(self.MLP1(h)))

        # Stage 2 - cross-dimension: one sequence of series per (batch, segment) pair,
        # routed through the learnable router vectors.
        per_segment = rearrange(h, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=n_batch)
        routers = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=n_batch)
        gathered = self.dim_sender(routers, per_segment, per_segment)
        scattered = self.dim_receiver(per_segment, gathered, gathered)
        h = self.norm3(per_segment + self.dropout(scattered))
        h = self.norm4(h + self.dropout(self.MLP2(h)))

        return rearrange(h, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b=n_batch)
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
from einops import rearrange
|
|
3
|
+
from .attn import AttentionLayer, TwoStageAttentionLayer
|
|
4
|
+
|
|
5
|
+
class DecoderLayer(nn.Module):
    '''
    One Crossformer decoder layer; it also emits a prediction at its own scale.

    The layer applies two-stage self-attention, then cross-attention against the
    matching encoder output, a residual MLP, and finally a linear head mapping
    each d_model vector to ``seg_len`` predicted steps.
    '''
    def __init__(self, seg_len, d_model, n_heads, d_ff=None, dropout=0.1, out_seg_num = 10, factor = 10):
        super(DecoderLayer, self).__init__()
        # Creation order fixes the random initialisation sequence; keep it.
        self.self_attention = TwoStageAttentionLayer(out_seg_num, factor, d_model, n_heads, d_ff, dropout)
        self.cross_attention = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model),
                                  nn.GELU(),
                                  nn.Linear(d_model, d_model))
        self.linear_pred = nn.Linear(d_model, seg_len)

    def forward(self, x, cross):
        '''
        x: the output of the previous decoder layer, [b, ts_d, out_seg_num, d_model]
        cross: the output of the corresponding encoder layer, [b, ts_d, in_seg_num, d_model]

        Returns (dec_output [b, ts_d, seg_dec_num, d_model],
                 layer_predict [b, ts_d * seg_num, seg_len]).
        '''
        n_batch = x.shape[0]
        x = self.self_attention(x)

        queries = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model')
        memory = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model')
        attended = self.cross_attention(queries, memory, memory)

        h = self.norm1(queries + self.dropout(attended))
        dec_output = self.norm2(h + self.MLP1(h))

        dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b=n_batch)
        projected = self.linear_pred(dec_output)
        layer_predict = rearrange(projected, 'b out_d seg_num seg_len -> b (out_d seg_num) seg_len')

        return dec_output, layer_predict
|
|
45
|
+
|
|
46
|
+
class Decoder(nn.Module):
    '''
    Crossformer decoder: the final prediction is the sum of the per-scale
    predictions produced by each DecoderLayer.
    '''
    def __init__(self, seg_len, d_layers, d_model, n_heads, d_ff, dropout,
                 router=False, out_seg_num = 10, factor=10):
        super(Decoder, self).__init__()

        # ``router`` is stored but not used by this class.
        self.router = router
        self.decode_layers = nn.ModuleList(
            [DecoderLayer(seg_len, d_model, n_heads, d_ff, dropout, out_seg_num, factor)
             for _ in range(d_layers)]
        )

    def forward(self, x, cross):
        '''
        x: decoder input, [b, ts_d, out_seg_num, d_model]
        cross: per-layer encoder outputs; ``cross[i]`` feeds decoder layer i.

        Returns the summed prediction rearranged to [b, seg_num * seg_len, ts_d].
        '''
        n_series = x.shape[1]
        total = None
        for level, layer in enumerate(self.decode_layers):
            x, level_predict = layer(x, cross[level])
            total = level_predict if total is None else total + level_predict

        return rearrange(total, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d=n_series)
|
|
77
|
+
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
from einops import rearrange
|
|
3
|
+
|
|
4
|
+
class DSW_embedding(nn.Module):
    '''
    Segment-wise embedding used by Crossformer (DSW presumably stands for
    Dimension-Segment-Wise).

    Every series is split into non-overlapping segments of ``seg_len`` steps and
    each segment is mapped to a ``d_model`` vector by one shared linear layer.
    '''
    def __init__(self, seg_len, d_model):
        super(DSW_embedding, self).__init__()
        self.seg_len = seg_len
        self.linear = nn.Linear(seg_len, d_model)

    def forward(self, x):
        '''
        x: [batch, ts_len, ts_dim]; the einops pattern requires ts_len to be a
           multiple of seg_len.
        Returns [batch, ts_dim, seg_num, d_model] with seg_num = ts_len // seg_len.
        '''
        n_batch, _, n_dims = x.shape
        segments = rearrange(x, 'b (seg_num seg_len) d -> (b d seg_num) seg_len', seg_len=self.seg_len)
        embedded = self.linear(segments)
        return rearrange(embedded, '(b d seg_num) d_model -> b d seg_num d_model', b=n_batch, d=n_dims)
|