dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/base_v2.py
@@ -0,0 +1,444 @@
+
+ from torch import optim
+ import torch
+ import lightning.pytorch as pl
+ from torch.optim.lr_scheduler import StepLR
+ from abc import abstractmethod
+ from .utils import SinkhornDistance, SoftDTWBatch,PathDTWBatch,pairwise_distances
+ from ..data_structure.utils import beauty_string
+ from .samformer.utils import SAM
+ from .utils import get_scope
+ import numpy as np
+ from aim import Image
+ import matplotlib.pyplot as plt
+ from typing import List, Union
+ from .utils import QuantileLossMO
+ import torch.nn as nn
+
+
+ def standardize_momentum(x,order):
+     mean = torch.mean(x,1).unsqueeze(1).repeat(1,x.shape[1],1)
+     num = torch.pow(x-mean,order).mean(axis=1)
+     #den = torch.sqrt(torch.pow(x-mean,2).mean(axis=1)+1e-8)
+     #den = torch.pow(den,order)
+
+     return num#/den
+
+
+ def dilate_loss(outputs, targets, alpha, gamma, device):
+     # outputs, targets: shape (batch_size, N_output, 1)
+     batch_size, N_output = outputs.shape[0:2]
+     loss_shape = 0
+     softdtw_batch = SoftDTWBatch.apply
+     D = torch.zeros((batch_size, N_output,N_output )).to(device)
+     for k in range(batch_size):
+         Dk = pairwise_distances(targets[k,:,:].view(-1,1),outputs[k,:,:].view(-1,1))
+         D[k:k+1,:,:] = Dk
+     loss_shape = softdtw_batch(D,gamma)
+
+     path_dtw = PathDTWBatch.apply
+     path = path_dtw(D,gamma)
+     Omega = pairwise_distances(torch.range(1,N_output).view(N_output,1)).to(device)
+     loss_temporal = torch.sum( path*Omega ) / (N_output*N_output)
+     loss = alpha*loss_shape+ (1-alpha)*loss_temporal
+     return loss#, loss_shape, loss_temporal
+
+
+ class Base(pl.LightningModule):
+
+     ############### SET THE PROPERTIES OF THE ARCHITECTURE##############
+
+     handle_multivariate = False
+     handle_future_covariates = False
+     handle_categorical_variables = False
+     handle_quantile_loss = False
+     description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+     #####################################################################
+     @abstractmethod
+     def __init__(self,verbose:bool,
+                  past_steps:int,
+                  future_steps:int,
+                  past_channels:int,
+                  future_channels:int,
+                  out_channels:int,
+                  embs_past:List[int],
+                  embs_fut:List[int],
+                  n_classes:int=0,
+                  persistence_weight:float=0.0,
+                  loss_type: str='l1',
+                  quantiles:List[int]=[],
+                  reduction_mode:str = 'mean',
+                  use_classical_positional_encoder:bool=False,
+                  emb_dim: int=16,
+
+                  optim:Union[str,None]=None,
+                  optim_config:dict=None,
+                  scheduler_config:dict=None,):
+         """
+         This is the basic model; each model implemented must overwrite the init method and the forward method.
+         The inference step is optional: by default it uses the forward method, but for recurrent
+         networks you should implement your own method.
+
+         Args:
+             verbose (bool): Flag to enable verbose logging.
+             past_steps (int): Number of past time steps to consider.
+             future_steps (int): Number of future time steps to predict.
+             past_channels (int): Number of channels in the past input data.
+             future_channels (int): Number of channels in the future input data.
+             out_channels (int): Number of output channels.
+             embs_past (List[int]): List of embedding dimensions for past data.
+             embs_fut (List[int]): List of embedding dimensions for future data.
+             n_classes (int, optional): Number of classes for classification. Defaults to 0.
+             persistence_weight (float, optional): Weight for persistence in loss calculation. Defaults to 0.0.
+             loss_type (str, optional): Type of loss function to use ('l1' or 'mse'). Defaults to 'l1'.
+             quantiles (List[int], optional): List of quantiles for quantile loss. Defaults to an empty list.
+             reduction_mode (str, optional): Reduction mode for the categorical embedding layer ('mean', 'sum', 'none'). Defaults to 'mean'.
+             use_classical_positional_encoder (bool, optional): If True, use the classical positional encoder; otherwise use an embedding layer for the positions as well. Defaults to False.
+             emb_dim (int, optional): Dimension of categorical embeddings. Defaults to 16.
+             optim (Union[str, None], optional): Optimizer type. Defaults to None.
+             optim_config (dict, optional): Configuration for the optimizer. Defaults to None.
+             scheduler_config (dict, optional): Configuration for the learning rate scheduler. Defaults to None.
+
+         Raises:
+             AssertionError: If the number of quantiles is not equal to 3 when quantiles are provided.
+             AssertionError: If the number of output channels is not 1 for classification tasks.
+         """
+
+
+         beauty_string('V2','block',True)
+         super(Base, self).__init__()
+         self.save_hyperparameters(logger=False)
+         self.count_epoch = 0
+         self.initialize = False
+         self.train_loss_epoch = -100.0
+         self.verbose = verbose
+         self.name = self.__class__.__name__
+         self.train_epoch_metrics = []
+         self.validation_epoch_metrics = []
+
+         self.use_quantiles = True if len(quantiles)>0 else False
+         self.quantiles = quantiles
+         self.optim = optim
+         self.optim_config = optim_config
+         self.scheduler_config = scheduler_config
+         self.loss_type = loss_type
+         self.persistence_weight = persistence_weight
+         self.use_classical_positional_encoder = use_classical_positional_encoder
+         self.reduction_mode = reduction_mode
+         self.past_steps = past_steps
+         self.future_steps = future_steps
+         self.embs_past = embs_past
+         self.embs_fut = embs_fut
+         self.past_channels = past_channels
+         self.future_channels = future_channels
+         self.emb_dim = emb_dim
+         self.out_channels = out_channels
+         self.n_classes = n_classes
+         if n_classes==0:
+             self.is_classification = False
+             if len(self.quantiles)>0:
+                 assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles permitted','info',True)
+                 self.use_quantiles = True
+                 self.mul = len(self.quantiles)
+                 self.loss = QuantileLossMO(quantiles)
+             else:
+                 self.use_quantiles = False
+                 self.mul = 1
+                 if self.loss_type == 'mse':
+                     self.loss = nn.MSELoss()
+                 else:
+                     self.loss = nn.L1Loss()
+         else:
+             self.is_classification = True
+             self.use_quantiles = False
+             self.mul = n_classes
+             self.loss = torch.nn.CrossEntropyLoss()
+             assert self.out_channels==1, "Classification requires only one channel"
+
+
+         self.future_steps = future_steps
+
+         beauty_string(self.description,'info',True)
+     @abstractmethod
+     def forward(self, batch:dict)-> torch.tensor:
+         """Forward method used during the training loop
+
+         Args:
+             batch (dict): the batch structure. The keys are:
+                 y : the target variable(s). This is always present
+                 x_num_past: the numerical past variables. This is always present
+                 x_num_future: the numerical future variables
+                 x_cat_past: the categorical past variables
+                 x_cat_future: the categorical future variables
+                 idx_target: index of target features in the past array
+
+
+         Returns:
+             torch.tensor: output of the model
+         """
+         return None
+
+
+
+     def inference(self, batch:dict)->torch.tensor:
+         """Usually it is ok to return the output of the forward method, but sometimes it is not (e.g. RNN)
+
+         Args:
+             batch (dict): batch
+
+         Returns:
+             torch.tensor: result
+         """
+         return self(batch)
+
+     def configure_optimizers(self):
+         """
+         Each model has optim_config and scheduler_config
+
+         :meta private:
+         """
+
+         self.has_sam_optim = False
+         if self.optim_config is None:
+             self.optim_config = {'lr': 5e-05}
+
+
+         if self.optim is None:
+             optimizer = optim.Adam(self.parameters(), **self.optim_config)
+             self.initialize = True
+
+         else:
+             if self.initialize is False:
+                 if self.optim=='SAM':
+                     self.has_sam_optim = True
+                     self.automatic_optimization = False
+                     self.my_step = 0
+
+                 else:
+                     self.optim = eval(self.optim)
+                     self.has_sam_optim = False
+                     self.automatic_optimization = True
+
+             beauty_string(self.optim,'',self.verbose)
+             if self.has_sam_optim:
+                 optimizer = SAM(self.parameters(), base_optimizer=torch.optim.Adam, **self.optim_config)
+             else:
+                 optimizer = self.optim(self.parameters(), **self.optim_config)
+             beauty_string(optimizer,'',self.verbose)
+             self.initialize = True
+         self.lr = self.optim_config['lr'] ##CHECK THIS
+         if self.scheduler_config is not None:
+             scheduler = StepLR(optimizer,**self.scheduler_config)
+             return [optimizer], [scheduler]
+         else:
+             return optimizer
+
+
+     def training_step(self, batch, batch_idx):
+         """
+         pytorch lightning stuff
+
+         :meta private:
+         """
+
+         #loss = self.compute_loss(batch,y_hat)
+         #import pdb
+         #pdb.set_trace()
+
+         if self.has_sam_optim:
+
+             opt = self.optimizers()
+             def closure():
+                 opt.zero_grad()
+                 y_hat = self(batch)
+                 loss = self.compute_loss(batch,y_hat)
+                 self.manual_backward(loss)
+                 return loss
+
+             opt.step(closure)
+             y_hat = self(batch)
+             loss = self.compute_loss(batch,y_hat)
+
+             #opt.first_step(zero_grad=True)
+
+             #y_hat = self(batch)
+             #loss = self.compute_loss(batch, y_hat)
+             #self.my_step+=1
+             #self.manual_backward(loss,retain_graph=True)
+             #opt.second_step(zero_grad=True)
+             #self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
+             #self.log("global_step", self.my_step, on_step=True) # Correct way to log
+
+
+             #self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.increment("optimizer")
+         else:
+             y_hat = self(batch)
+             loss = self.compute_loss(batch,y_hat)
+
+         self.train_epoch_metrics.append(loss.item())
+         return loss
+
+
+     def validation_step(self, batch, batch_idx):
+         """
+         pytorch lightning stuff
+
+         :meta private:
+         """
+         y_hat = self(batch)
+         if batch_idx==0:
+             if self.use_quantiles:
+                 idx = 1
+             else:
+                 idx = 0
+             #track the predictions! We can do better than this but maybe it is better to first update pytorch-lightning
+
+             if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:
+
+                 for i in range(batch['y'].shape[2]):
+                     real = batch['y'][0,:,i].cpu().detach().numpy()
+                     pred = y_hat[0,:,i,idx].cpu().detach().numpy()
+                     fig, ax = plt.subplots(figsize=(7,5))
+                     ax.plot(real,'o-',label='real')
+                     ax.plot(pred,'o-',label='pred')
+                     ax.legend()
+                     ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                     self.logger.experiment.track(Image(fig), name='cm_training_end')
+                     #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+         self.validation_epoch_metrics.append(self.compute_loss(batch,y_hat))
+         return
+
+
+     def on_validation_epoch_end(self):
+         """
+         pytorch lightning stuff
+
+         :meta private:
+         """
+         avg = torch.stack(self.validation_epoch_metrics).mean()
+         self.validation_epoch_metrics = []
+         self.log("val_loss", avg,sync_dist=True)
+         beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
+
+     def on_train_epoch_end(self):
+
+         """
+         pytorch lightning stuff
+
+         :meta private:
+         """
+         avg = np.stack(self.train_epoch_metrics).mean()
+         self.log("train_loss", avg,sync_dist=True)
+         self.count_epoch+=1
+         self.train_epoch_metrics = []
+         self.train_loss_epoch = avg
+
+     def compute_loss(self,batch,y_hat):
+         """
+         custom loss calculation
+
+         :meta private:
+         """
+
+         if self.use_quantiles is False:
+             initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
+         else:
+             initial_loss = self.loss(y_hat, batch['y'])
+         x = batch['x_num_past'].to(self.device)
+         idx_target = batch['idx_target'][0]
+         x_start = x[:,-1,idx_target].unsqueeze(1)
+         y_persistence = x_start.repeat(1,self.future_steps,1)
+
+         ##generally you want to work without quantile loss
+         if self.use_quantiles is False:
+             x = y_hat[:,:,:,0]
+         else:
+             x = y_hat[:,:,:,1]
+
+
+         if self.loss_type == 'linear_penalization':
+             persistence_error = (2.0-10.0*torch.clamp( torch.abs((y_persistence-x)/(0.001+torch.abs(y_persistence))),min=0.0,max=max(0.05,0.1*(1+np.log10(self.persistence_weight) ))))
+             loss = torch.mean(torch.abs(x- batch['y'])*persistence_error)
+
+         elif self.loss_type == 'mda':
+             #import pdb
+             #pdb.set_trace()
+             mda = (1-torch.mean( torch.sign(torch.diff(x,axis=1))*torch.sign(torch.diff(batch['y'],axis=1))))
+             loss = torch.mean( torch.abs(x-batch['y']).mean(axis=1).flatten()) + self.persistence_weight*mda/10
+
+
+
+         elif self.loss_type == 'exponential_penalization':
+             weights = (1+self.persistence_weight*torch.exp(-torch.abs(y_persistence-x)))
+             loss = torch.mean(torch.abs(x- batch['y'])*weights)
+
+         elif self.loss_type=='sinkhorn':
+             sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='mean')
+             loss = sinkhorn.compute(x,batch['y'])
+
+         elif self.loss_type == 'additive_iv':
+             std = torch.sqrt(torch.var(batch['y'], dim=(1))+ 1e-8) ##--> BSxChannel
+             x_std = torch.sqrt(torch.var(x, dim=(1))+ 1e-8)
+             loss = torch.mean( torch.abs(x-batch['y']).mean(axis=1).flatten() + self.persistence_weight*torch.abs(x_std-std).flatten())
+
+         elif self.loss_type == 'multiplicative_iv':
+             std = torch.sqrt(torch.var(batch['y'], dim=(1))+ 1e-8) ##--> BSxChannel
+             x_std = torch.sqrt(torch.var(x, dim=(1))+ 1e-8)
+             if self.persistence_weight>0:
+                 loss = torch.mean( torch.abs(x-batch['y']).mean(axis=1)*torch.abs(x_std-std))
+             else:
+                 loss = torch.mean( torch.abs(x-batch['y']).mean(axis=1))
+         elif self.loss_type=='global_iv':
+             std_real = torch.sqrt(torch.var(batch['y'], dim=(0,1)))
+             std_predict = torch.sqrt(torch.var(x, dim=(0,1)))
+             loss = initial_loss + self.persistence_weight*torch.abs(std_real-std_predict).mean()
+
+         elif self.loss_type=='smape':
+             loss = torch.mean(2*torch.abs(x-batch['y']) / (torch.abs(x)+torch.abs(batch['y'])))
+
+         elif self.loss_type=='triplet':
+             loss_fn = torch.nn.TripletMarginLoss(margin=0.01, p=1.0,swap=False)
+             loss = initial_loss + self.persistence_weight*loss_fn(x, batch['y'], y_persistence)
+
+         elif self.loss_type=='high_order':
+             loss = initial_loss
+             for i in range(2,5):
+                 mom_real = standardize_momentum( batch['y'],i)
+                 mom_pred = standardize_momentum(x,i)
+
+                 mom_loss = torch.abs(mom_real-mom_pred).mean()
+                 loss+=self.persistence_weight*mom_loss
+
+         elif self.loss_type=='dilated':
+             #BxLxCxMUL
+             if self.persistence_weight==0.1:
+                 alpha = 0.25
+             if self.persistence_weight==1:
+                 alpha = 0.5
+             else:
+                 alpha =0.75
+             alpha = self.persistence_weight
+             gamma = 0.01
+             loss = 0
+             ##no multichannel here
+             for i in range(y_hat.shape[2]):
+                 ##error here
+
+                 loss+= dilate_loss( batch['y'][:,:,i:i+1],x[:,:,i:i+1], alpha, gamma, y_hat.device)
+
+         elif self.loss_type=='huber':
+             loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight/10)
+             #loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
+             if self.use_quantiles is False:
+                 x = y_hat[:,:,:,0]
+             else:
+                 x = y_hat[:,:,:,1]
+             BS = x.shape[0]
+             loss = loss(y_hat.reshape(BS,-1), batch['y'].reshape(BS,-1))
+
+
+         else:
+             loss = initial_loss
+
+
+         return loss
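
Not part of the package diff: a minimal sketch of how a concrete architecture might build on this Base class, assuming the constructor arguments and batch keys documented above. The ToyLinear class, its import path, and the keyword-argument call are illustrative only.

import torch
import torch.nn as nn
from dsipts.models.base_v2 import Base  # assumed import path


class ToyLinear(Base):
    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = True

    def __init__(self, **kwargs):
        # Base.__init__ sets self.mul (1, 3 for quantile loss, or n_classes),
        # self.loss, self.past_steps, self.future_steps, self.out_channels, ...
        super().__init__(**kwargs)
        self.linear = nn.Linear(self.past_steps * self.past_channels,
                                self.future_steps * self.out_channels * self.mul)

    def forward(self, batch: dict) -> torch.Tensor:
        # compute_loss expects y_hat with shape (B, future_steps, out_channels, mul)
        x = batch['x_num_past']              # (B, past_steps, past_channels)
        B = x.shape[0]
        out = self.linear(x.reshape(B, -1))
        return out.reshape(B, self.future_steps, self.out_channels, self.mul)


model = ToyLinear(verbose=False, past_steps=16, future_steps=8,
                  past_channels=3, future_channels=0, out_channels=1,
                  embs_past=[], embs_fut=[], quantiles=[0.1, 0.5, 0.9])
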
dsipts/models/crossformer/__init__.py: file without changes
dsipts/models/crossformer/attn.py
@@ -0,0 +1,118 @@
+ import torch
+ import torch.nn as nn
+ from einops import rearrange, repeat
+
+
+ from math import sqrt
+
+ class FullAttention(nn.Module):
+     '''
+     The Attention operation
+     '''
+     def __init__(self, scale=None, attention_dropout=0.1):
+         super(FullAttention, self).__init__()
+         self.scale = scale
+         self.dropout = nn.Dropout(attention_dropout)
+
+     def forward(self, queries, keys, values):
+         B, L, H, E = queries.shape
+         _, S, _, D = values.shape
+         scale = self.scale or 1./sqrt(E)
+
+         scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+         A = self.dropout(torch.softmax(scale * scores, dim=-1))
+         V = torch.einsum("bhls,bshd->blhd", A, values)
+
+         return V.contiguous()
+
+
+ class AttentionLayer(nn.Module):
+     '''
+     The Multi-head Self-Attention (MSA) Layer
+     '''
+     def __init__(self, d_model, n_heads, d_keys=None, d_values=None, mix=True, dropout = 0.1):
+         super(AttentionLayer, self).__init__()
+
+         d_keys = d_keys or (d_model//n_heads)
+         d_values = d_values or (d_model//n_heads)
+
+         self.inner_attention = FullAttention(scale=None, attention_dropout = dropout)
+         self.query_projection = nn.Linear(d_model, d_keys * n_heads)
+         self.key_projection = nn.Linear(d_model, d_keys * n_heads)
+         self.value_projection = nn.Linear(d_model, d_values * n_heads)
+         self.out_projection = nn.Linear(d_values * n_heads, d_model)
+         self.n_heads = n_heads
+         self.mix = mix
+
+     def forward(self, queries, keys, values):
+         B, L, _ = queries.shape
+         _, S, _ = keys.shape
+         H = self.n_heads
+
+         queries = self.query_projection(queries).view(B, L, H, -1)
+         keys = self.key_projection(keys).view(B, S, H, -1)
+         values = self.value_projection(values).view(B, S, H, -1)
+
+         out = self.inner_attention(
+             queries,
+             keys,
+             values,
+         )
+         if self.mix:
+             out = out.transpose(2,1).contiguous()
+         out = out.view(B, L, -1)
+
+         return self.out_projection(out)
+
+ class TwoStageAttentionLayer(nn.Module):
+     '''
+     The Two Stage Attention (TSA) Layer
+     input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]
+     '''
+     def __init__(self, seg_num, factor, d_model, n_heads, d_ff = None, dropout=0.1):
+         super(TwoStageAttentionLayer, self).__init__()
+         d_ff = d_ff or 4*d_model
+         self.time_attention = AttentionLayer(d_model, n_heads, dropout = dropout)
+         self.dim_sender = AttentionLayer(d_model, n_heads, dropout = dropout)
+         self.dim_receiver = AttentionLayer(d_model, n_heads, dropout = dropout)
+         self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))
+
+         self.dropout = nn.Dropout(dropout)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.norm4 = nn.LayerNorm(d_model)
+
+         self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff),
+                                   nn.GELU(),
+                                   nn.Linear(d_ff, d_model))
+         self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff),
+                                   nn.GELU(),
+                                   nn.Linear(d_ff, d_model))
+
+     def forward(self, x):
+         #Cross Time Stage: Directly apply MSA to each dimension
+         batch = x.shape[0]
+         time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
+         time_enc = self.time_attention(
+             time_in, time_in, time_in
+         )
+         dim_in = time_in + self.dropout(time_enc)
+         dim_in = self.norm1(dim_in)
+         dim_in = dim_in + self.dropout(self.MLP1(dim_in))
+         dim_in = self.norm2(dim_in)
+
+         #Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection
+         dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b = batch)
+         batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat = batch)
+         dim_buffer = self.dim_sender(batch_router, dim_send, dim_send)
+         dim_receive = self.dim_receiver(dim_send, dim_buffer, dim_buffer)
+         dim_enc = dim_send + self.dropout(dim_receive)
+         dim_enc = self.norm3(dim_enc)
+         dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc))
+         dim_enc = self.norm4(dim_enc)
+
+         final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b = batch)
+
+         return final_out
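
Not part of the diff: a quick shape check for TwoStageAttentionLayer, which keeps the documented [batch_size, D, seg_num, d_model] layout; the import path below is assumed from the file list above.

import torch
from dsipts.models.crossformer.attn import TwoStageAttentionLayer  # assumed import path

batch_size, ts_d, seg_num, d_model = 4, 7, 12, 64
layer = TwoStageAttentionLayer(seg_num=seg_num, factor=10, d_model=d_model, n_heads=8)

x = torch.randn(batch_size, ts_d, seg_num, d_model)
out = layer(x)
print(out.shape)  # torch.Size([4, 7, 12, 64]) -- same layout as the input
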
dsipts/models/crossformer/cross_decoder.py
@@ -0,0 +1,77 @@
+ import torch.nn as nn
+ from einops import rearrange
+ from .attn import AttentionLayer, TwoStageAttentionLayer
+
+ class DecoderLayer(nn.Module):
+     '''
+     The decoder layer of Crossformer, each layer will make a prediction at its scale
+     '''
+     def __init__(self, seg_len, d_model, n_heads, d_ff=None, dropout=0.1, out_seg_num = 10, factor = 10):
+         super(DecoderLayer, self).__init__()
+         self.self_attention = TwoStageAttentionLayer(out_seg_num, factor, d_model, n_heads,d_ff, dropout)
+         self.cross_attention = AttentionLayer(d_model, n_heads, dropout = dropout)
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+         self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model),
+                                   nn.GELU(),
+                                   nn.Linear(d_model, d_model))
+         self.linear_pred = nn.Linear(d_model, seg_len)
+
+     def forward(self, x, cross):
+         '''
+         x: the output of last decoder layer
+         cross: the output of the corresponding encoder layer
+         '''
+
+         batch = x.shape[0]
+         x = self.self_attention(x)
+         x = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model')
+
+         cross = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model')
+         tmp = self.cross_attention(
+             x, cross, cross,
+         )
+         x = x + self.dropout(tmp)
+         y = x = self.norm1(x)
+         y = self.MLP1(y)
+         dec_output = self.norm2(x+y)
+
+         dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b = batch)
+         layer_predict = self.linear_pred(dec_output)
+         layer_predict = rearrange(layer_predict, 'b out_d seg_num seg_len -> b (out_d seg_num) seg_len')
+
+         return dec_output, layer_predict
+
+ class Decoder(nn.Module):
+     '''
+     The decoder of Crossformer, making the final prediction by adding up predictions at each scale
+     '''
+     def __init__(self, seg_len, d_layers, d_model, n_heads, d_ff, dropout,\
+                 router=False, out_seg_num = 10, factor=10):
+         super(Decoder, self).__init__()
+
+         self.router = router
+         self.decode_layers = nn.ModuleList()
+         for i in range(d_layers):
+             self.decode_layers.append(DecoderLayer(seg_len, d_model, n_heads, d_ff, dropout, \
+                                                    out_seg_num, factor))
+
+     def forward(self, x, cross):
+         final_predict = None
+         i = 0
+
+         ts_d = x.shape[1]
+         for layer in self.decode_layers:
+             cross_enc = cross[i]
+             x, layer_predict = layer(x, cross_enc)
+             if final_predict is None:
+                 final_predict = layer_predict
+             else:
+                 final_predict = final_predict + layer_predict
+             i += 1
+
+         final_predict = rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d = ts_d)
+
+         return final_predict
+
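
Not part of the diff: a sketch of how this Decoder could be driven, assuming the import path below; each entry of cross stands in for one encoder scale, and the layers' per-scale predictions are summed into the final output.

import torch
from dsipts.models.crossformer.cross_decoder import Decoder  # assumed import path

b, ts_d, d_model, n_heads, d_ff = 2, 3, 64, 8, 128
seg_len, d_layers, out_seg_num, in_seg_num = 6, 2, 4, 8

dec = Decoder(seg_len, d_layers, d_model, n_heads, d_ff, dropout=0.1,
              out_seg_num=out_seg_num, factor=10)

x = torch.randn(b, ts_d, out_seg_num, d_model)                      # decoder input
cross = [torch.randn(b, ts_d, in_seg_num, d_model) for _ in range(d_layers)]

y = dec(x, cross)
print(y.shape)  # torch.Size([2, 24, 3]) -> (b, out_seg_num * seg_len, ts_d)
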
dsipts/models/crossformer/cross_embed.py
@@ -0,0 +1,18 @@
+ import torch.nn as nn
+ from einops import rearrange
+
+ class DSW_embedding(nn.Module):
+     def __init__(self, seg_len, d_model):
+         super(DSW_embedding, self).__init__()
+         self.seg_len = seg_len
+
+         self.linear = nn.Linear(seg_len, d_model)
+
+     def forward(self, x):
+         batch, ts_len, ts_dim = x.shape
+
+         x_segment = rearrange(x, 'b (seg_num seg_len) d -> (b d seg_num) seg_len', seg_len = self.seg_len)
+         x_embed = self.linear(x_segment)
+         x_embed = rearrange(x_embed, '(b d seg_num) d_model -> b d seg_num d_model', b = batch, d = ts_dim)
+
+         return x_embed
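
Not part of the diff: DSW_embedding splits every channel into seg_len-long segments and embeds each segment into d_model dimensions, so ts_len must be divisible by seg_len; the import path below is assumed.

import torch
from dsipts.models.crossformer.cross_embed import DSW_embedding  # assumed import path

batch, ts_len, ts_dim, seg_len, d_model = 4, 96, 7, 12, 64        # 96 is divisible by 12
emb = DSW_embedding(seg_len, d_model)

x = torch.randn(batch, ts_len, ts_dim)
x_embed = emb(x)
print(x_embed.shape)  # torch.Size([4, 7, 8, 64]) -> (batch, ts_dim, ts_len // seg_len, d_model)
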