dsipts 1.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dsipts might be problematic. Click here for more details.

Files changed (81) hide show
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/base.py ADDED
@@ -0,0 +1,439 @@
1
+
2
+ from torch import optim
3
+ import torch
4
+ import pytorch_lightning as pl
5
+ from torch.optim.lr_scheduler import StepLR
6
+ from abc import abstractmethod
7
+ from .utils import SinkhornDistance, SoftDTWBatch,PathDTWBatch,pairwise_distances
8
+ from ..data_structure.utils import beauty_string
9
+ from .samformer.utils import SAM
10
+ from .utils import get_scope
11
+ import numpy as np
12
+ from aim import Image
13
+ import matplotlib.pyplot as plt
14
+ from typing import List, Union
15
+ from .utils import QuantileLossMO
16
+ import torch.nn as nn
17
+
18
def standardize_momentum(x, order):
    """Return the order-th central moment of ``x`` along dim 1 (no normalization).

    Args:
        x: tensor of shape (batch, length, channels).
        order: moment order (e.g. 2 for variance-like, 3 for skew-like).

    Returns:
        torch.Tensor of shape (batch, channels): mean of (x - mean)**order over dim 1.
    """
    centered = x - torch.mean(x, 1).unsqueeze(1)
    return torch.pow(centered, order).mean(axis=1)
25
+
26
+
27
def dilate_loss(outputs, targets, alpha, gamma, device):
    """Compute the DILATE loss: a soft-DTW shape term plus a temporal alignment term.

    Args:
        outputs: predictions, shape (batch_size, N_output, 1).
        targets: ground truth, shape (batch_size, N_output, 1).
        alpha: trade-off between shape loss (alpha) and temporal loss (1 - alpha).
        gamma: soft-DTW smoothing parameter.
        device: device on which the intermediate cost matrices are allocated.

    Returns:
        Scalar tensor: alpha*loss_shape + (1-alpha)*loss_temporal.
    """
    batch_size, N_output = outputs.shape[0:2]
    softdtw_batch = SoftDTWBatch.apply
    # Per-series pairwise cost matrix between target and prediction samples.
    D = torch.zeros((batch_size, N_output, N_output)).to(device)
    for k in range(batch_size):
        Dk = pairwise_distances(targets[k, :, :].view(-1, 1), outputs[k, :, :].view(-1, 1))
        D[k:k + 1, :, :] = Dk
    loss_shape = softdtw_batch(D, gamma)

    path_dtw = PathDTWBatch.apply
    path = path_dtw(D, gamma)
    # torch.arange(1., N+1) replaces the deprecated torch.range(1, N)
    # (same float values 1..N_output inclusive).
    Omega = pairwise_distances(torch.arange(1., N_output + 1).view(N_output, 1)).to(device)
    # Temporal term: penalize alignment paths far from the diagonal.
    loss_temporal = torch.sum(path * Omega) / (N_output * N_output)
    loss = alpha * loss_shape + (1 - alpha) * loss_temporal
    return loss
44
+
45
+
46
class Base(pl.LightningModule):
    """Abstract base LightningModule shared by all dsipts forecasting models.

    Subclasses must override ``__init__`` and ``forward``; ``inference`` may be
    overridden for models (e.g. RNNs) whose prediction path differs from training.
    """

    ############### SET THE PROPERTIES OF THE ARCHITECTURE##############
    # Capability flags: concrete subclasses override these to declare what they
    # support; here they are only used to build the printable description.
    handle_multivariate = False
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False
    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
    #####################################################################
    @abstractmethod
    def __init__(self,verbose:bool,

                 past_steps:int,
                 future_steps:int,
                 past_channels:int,
                 future_channels:int,
                 out_channels:int,
                 embs_past:List[int],
                 embs_fut:List[int],
                 n_classes:int=0,

                 persistence_weight:float=0.0,
                 loss_type: str='l1',
                 quantiles:List[int]=[],
                 reduction_mode:str = 'mean',
                 use_classical_positional_encoder:bool=False,
                 emb_dim: int=16,

                 optim:Union[str,None]=None,
                 optim_config:dict=None,
                 scheduler_config:dict=None,):
        """
        This is the basic model, each model implemented must overwrite the init method and the forward method.
        The inference step is optional, by default it uses the forward method but for recurrent
        network you should implement your own method

        Args:
            verbose (bool): Flag to enable verbose logging.
            past_steps (int): Number of past time steps to consider.
            future_steps (int): Number of future time steps to predict.
            past_channels (int): Number of channels in the past input data.
            future_channels (int): Number of channels in the future input data.
            out_channels (int): Number of output channels.
            embs_past (List[int]): List of embedding dimensions for past data.
            embs_fut (List[int]): List of embedding dimensions for future data.
            n_classes (int, optional): Number of classes for classification. Defaults to 0.
            persistence_weight (float, optional): Weight for persistence in loss calculation. Defaults to 0.0.
            loss_type (str, optional): Type of loss function to use ('l1' or 'mse'). Defaults to 'l1'.
            quantiles (List[int], optional): List of quantiles for quantile loss. Defaults to an empty list.
            reduction_mode (str, optional): Mode for reduction for categorical embedding layer ('mean', 'sum', 'none'). Defaults to 'mean'.
            use_classical_positional_encoder (bool, optional): Flag to use classical positional encoding or using embedding layer also for the positions. Defaults to False.
            emb_dim (int, optional): Dimension of categorical embeddings. Defaults to 16.
            optim (Union[str, None], optional): Optimizer type. Defaults to None.
            optim_config (dict, optional): Configuration for the optimizer. Defaults to None.
            scheduler_config (dict, optional): Configuration for the learning rate scheduler. Defaults to None.

        Raises:
            AssertionError: If the number of quantiles is not equal to 3 when quantiles are provided.
            AssertionError: If the number of output channels is not 1 for classification tasks.
        """
        beauty_string('V2','block',True)
        super(Base, self).__init__()
        self.save_hyperparameters(logger=False)
        self.count_epoch = 0
        self.initialize = False          # flipped to True once configure_optimizers has run
        self.train_loss_epoch = -100.0   # sentinel value shown before the first training epoch completes
        self.verbose = verbose
        self.name = self.__class__.__name__
        self.train_epoch_metrics = []
        self.validation_epoch_metrics = []

        self.use_quantiles = True if len(quantiles)>0 else False
        self.quantiles = quantiles
        self.optim = optim
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config
        self.loss_type = loss_type
        self.persistence_weight = persistence_weight
        self.use_classical_positional_encoder = use_classical_positional_encoder
        self.reduction_mode = reduction_mode
        self.past_steps = past_steps
        self.future_steps = future_steps
        self.embs_past = embs_past
        self.embs_fut = embs_fut
        self.past_channels = past_channels
        self.future_channels = future_channels
        self.emb_dim = emb_dim
        self.out_channels = out_channels
        self.n_classes = n_classes
        if n_classes==0:
            # Regression: either a 3-quantile loss or a plain MSE/L1 loss.
            self.is_classification = False
            if len(self.quantiles)>0:
                assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
                self.use_quantiles = True
                self.mul = len(self.quantiles)  # output multiplier: one slice per quantile
                self.loss = QuantileLossMO(quantiles)
            else:
                self.use_quantiles = False
                self.mul = 1
                if self.loss_type == 'mse':
                    self.loss = nn.MSELoss()
                else:
                    self.loss = nn.L1Loss()
        else:
            # Classification: quantiles are ignored and the output multiplier
            # becomes the number of classes.
            self.is_classification = True
            self.use_quantiles = False
            self.mul = n_classes
            self.loss = torch.nn.CrossEntropyLoss()
            assert self.out_channels==1, "Classification require only one channel"

        # NOTE(review): duplicate assignment — future_steps was already stored above.
        self.future_steps = future_steps

        beauty_string(self.description,'info',True)
    @abstractmethod
    def forward(self, batch:dict)-> torch.tensor:
        """Forward method used during the training loop

        Args:
            batch (dict): the batch structure. The keys are:
                y : the target variable(s). This is always present
                x_num_past: the numerical past variables. This is always present
                x_num_future: the numerical future variables
                x_cat_past: the categorical past variables
                x_cat_future: the categorical future variables
                idx_target: index of target features in the past array


        Returns:
            torch.tensor: output of the model
        """
        return None



    def inference(self, batch:dict)->torch.tensor:
        """Usually it is ok to return the output of the forward method but sometimes not (e.g. RNN)

        Args:
            batch (dict): batch

        Returns:
            torch.tensor: result
        """
        return self(batch)

    def configure_optimizers(self):
        """
        Each model has optim_config and scheduler_config

        :meta private:
        """

        self.has_sam_optim = False
        if self.optim_config is None:
            self.optim_config = {'lr': 5e-05}


        if self.optim is None:
            # Default path: plain Adam with the (possibly defaulted) optim_config.
            optimizer = optim.Adam(self.parameters(), **self.optim_config)
            self.initialize = True

        else:
            if self.initialize is False:
                if self.optim=='SAM':
                    # SAM needs a two-pass update, so Lightning's automatic
                    # optimization is disabled and training_step drives it manually.
                    self.has_sam_optim = True
                    self.automatic_optimization = False
                    self.my_step = 0

                else:
                    # NOTE(review): eval() resolves the configured string (e.g.
                    # 'torch.optim.AdamW') to an optimizer class — never feed
                    # untrusted configuration through here.
                    self.optim = eval(self.optim)
                    self.has_sam_optim = False
                    self.automatic_optimization = True

            beauty_string(self.optim,'',self.verbose)
            if self.has_sam_optim:
                optimizer = SAM(self.parameters(), base_optimizer=torch.optim.Adam, **self.optim_config)
            else:
                optimizer = self.optim(self.parameters(), **self.optim_config)
            beauty_string(optimizer,'',self.verbose)
            self.initialize = True
        self.lr = self.optim_config['lr']
        if self.scheduler_config is not None:
            scheduler = StepLR(optimizer,**self.scheduler_config)
            return [optimizer], [scheduler]
        else:
            return optimizer


    def training_step(self, batch, batch_idx):
        """
        PyTorch Lightning training hook

        :meta private:
        """

        #loss = self.compute_loss(batch,y_hat)
        #import pdb
        #pdb.set_trace()

        if self.has_sam_optim:
            # Manual optimization path: SAM performs its forward/backward passes
            # inside step(closure).
            opt = self.optimizers()
            def closure():
                opt.zero_grad()
                y_hat = self(batch)
                loss = self.compute_loss(batch,y_hat)
                self.manual_backward(loss)
                return loss

            opt.step(closure)
            # Extra forward pass only to obtain the loss value returned below.
            y_hat = self(batch)
            loss = self.compute_loss(batch,y_hat)

            #opt.first_step(zero_grad=True)

            #y_hat = self(batch)
            #loss = self.compute_loss(batch, y_hat)
            #self.my_step+=1
            #self.manual_backward(loss,retain_graph=True)
            #opt.second_step(zero_grad=True)
            #self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
            #self.log("global_step", self.my_step, on_step=True) # Correct way to log


            #self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.increment("optimizer")
        else:
            y_hat = self(batch)
            loss = self.compute_loss(batch,y_hat)
        return loss


    def validation_step(self, batch, batch_idx):
        """
        PyTorch Lightning validation hook

        :meta private:
        """
        y_hat = self(batch)
        if batch_idx==0:
            # With quantile loss the median sits at index 1 of the last axis,
            # otherwise the single point prediction sits at index 0.
            if self.use_quantiles:
                idx = 1
            else:
                idx = 0
            #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening

            # Plot roughly once per percent of max_epochs (at least every other epoch).
            if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:

                for i in range(batch['y'].shape[2]):
                    real = batch['y'][0,:,i].cpu().detach().numpy()
                    pred = y_hat[0,:,i,idx].cpu().detach().numpy()
                    fig, ax = plt.subplots(figsize=(7,5))
                    ax.plot(real,'o-',label='real')
                    ax.plot(pred,'o-',label='pred')
                    ax.legend()
                    ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
                    # Logged as an image to the aim experiment tracker.
                    self.logger.experiment.track(Image(fig), name='cm_training_end')
                    #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)

        return self.compute_loss(batch,y_hat)


    def validation_epoch_end(self, outs):
        """
        PyTorch Lightning epoch-end hook.

        NOTE(review): this hook name exists only in pytorch_lightning < 2.0 —
        confirm the pinned Lightning version.

        :meta private:
        """

        loss = torch.stack(outs).mean()
        self.log("val_loss", loss.item(),sync_dist=True)
        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {loss.item():.4f}','info',self.verbose)

    def training_epoch_end(self, outs):
        """
        PyTorch Lightning epoch-end hook.

        NOTE(review): this hook name exists only in pytorch_lightning < 2.0 —
        confirm the pinned Lightning version.

        :meta private:
        """

        # Average of per-step losses. The generator variable shadows the list
        # name, but the outer len(outs) still refers to the list.
        loss = sum(outs['loss'] for outs in outs) / len(outs)
        self.log("train_loss", loss.item(),sync_dist=True)
        self.count_epoch+=1

        self.train_loss_epoch = loss.item()

    def compute_loss(self,batch,y_hat):
        """
        custom loss calculation

        :meta private:
        """

        # Base loss: drop the quantile axis for point predictions, otherwise the
        # quantile loss consumes the full BxLxCxMUL tensor.
        if self.use_quantiles is False:
            initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
        else:
            initial_loss = self.loss(y_hat, batch['y'])
        x = batch['x_num_past'].to(self.device)
        idx_target = batch['idx_target'][0]
        # Persistence baseline: repeat the last observed target over the horizon.
        x_start = x[:,-1,idx_target].unsqueeze(1)
        y_persistence = x_start.repeat(1,self.future_steps,1)

        ##generally you want to work without quantile loss
        if self.use_quantiles is False:
            x = y_hat[:,:,:,0]
        else:
            x = y_hat[:,:,:,1]  # median quantile


        if self.loss_type == 'linear_penalization':
            # NOTE(review): this branch is followed by an independent if/elif
            # chain whose final else resets loss = initial_loss, so the value
            # computed here is overwritten — it likely belongs in the chain
            # below. Also np.log10(self.persistence_weight) fails for weight 0.
            persistence_error = (2.0-10.0*torch.clamp( torch.abs((y_persistence-x)/(0.001+torch.abs(y_persistence))),min=0.0,max=max(0.05,0.1*(1+np.log10(self.persistence_weight) ))))
            loss = torch.mean(torch.abs(x- batch['y'])*persistence_error)

        if self.loss_type == 'mda':
            #import pdb
            #pdb.set_trace()
            # Mean-directional-accuracy penalty on the sign of step-to-step changes.
            mda = (1-torch.mean( torch.sign(torch.diff(x,axis=1))*torch.sign(torch.diff(batch['y'],axis=1))))
            loss = torch.mean( torch.abs(x-batch['y']).mean(axis=1).flatten()) + self.persistence_weight*mda/10



        elif self.loss_type == 'exponential_penalization':
            # Up-weight errors where the prediction hugs the persistence baseline.
            weights = (1+self.persistence_weight*torch.exp(-torch.abs(y_persistence-x)))
            loss = torch.mean(torch.abs(x- batch['y'])*weights)

        elif self.loss_type=='sinkhorn':
            sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='mean')
            loss = sinkhorn.compute(x,batch['y'])

        elif self.loss_type == 'additive_iv':
            # Additive penalty on the gap between predicted and real per-series std.
            std = torch.sqrt(torch.var(batch['y'], dim=(1))+ 1e-8) ##--> BSxChannel
            x_std = torch.sqrt(torch.var(x, dim=(1))+ 1e-8)
            loss = torch.mean( torch.abs(x-batch['y']).mean(axis=1).flatten() + self.persistence_weight*torch.abs(x_std-std).flatten())

        elif self.loss_type == 'multiplicative_iv':
            # Multiplicative variant of the std-matching penalty.
            std = torch.sqrt(torch.var(batch['y'], dim=(1))+ 1e-8) ##--> BSxChannel
            x_std = torch.sqrt(torch.var(x, dim=(1))+ 1e-8)
            if self.persistence_weight>0:
                loss = torch.mean( torch.abs(x-batch['y']).mean(axis=1)*torch.abs(x_std-std))
            else:
                loss = torch.mean( torch.abs(x-batch['y']).mean(axis=1))
        elif self.loss_type=='global_iv':
            # Match the batch-wide standard deviation of the predictions.
            std_real = torch.sqrt(torch.var(batch['y'], dim=(0,1)))
            std_predict = torch.sqrt(torch.var(x, dim=(0,1)))
            loss = initial_loss + self.persistence_weight*torch.abs(std_real-std_predict).mean()

        elif self.loss_type=='smape':
            loss = torch.mean(2*torch.abs(x-batch['y']) / (torch.abs(x)+torch.abs(batch['y'])))

        elif self.loss_type=='triplet':
            # Pull predictions toward the target and away from the persistence baseline.
            loss_fn = torch.nn.TripletMarginLoss(margin=0.01, p=1.0,swap=False)
            loss = initial_loss + self.persistence_weight*loss_fn(x, batch['y'], y_persistence)

        elif self.loss_type=='high_order':
            # Add penalties on central moments of order 2..4.
            loss = initial_loss
            for i in range(2,5):
                mom_real = standardize_momentum( batch['y'],i)
                mom_pred = standardize_momentum(x,i)

                mom_loss = torch.abs(mom_real-mom_pred).mean()
                loss+=self.persistence_weight*mom_loss

        elif self.loss_type=='dilated':
            #BxLxCxMUL
            # NOTE(review): the alpha selection below is dead code (first branch
            # is `if` not `elif`, and alpha is immediately overwritten by
            # persistence_weight). The call also passes the target as
            # dilate_loss's `outputs` argument — confirm intended order.
            if self.persistence_weight==0.1:
                alpha = 0.25
            if self.persistence_weight==1:
                alpha = 0.5
            else:
                alpha =0.75
            alpha = self.persistence_weight
            gamma = 0.01
            loss = 0
            ##no multichannel here
            for i in range(y_hat.shape[2]):
                ##error here

                loss+= dilate_loss( batch['y'][:,:,i:i+1],x[:,:,i:i+1], alpha, gamma, y_hat.device)

        elif self.loss_type=='huber':
            loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight/10)
            #loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
            if self.use_quantiles is False:
                x = y_hat[:,:,:,0]
            else:
                x = y_hat[:,:,:,1]
            BS = x.shape[0]
            # NOTE(review): the Huber loss consumes the full y_hat (all quantile
            # slices flattened), not the x selected just above — shapes only
            # match batch['y'] when mul==1; confirm intent.
            loss = loss(y_hat.reshape(BS,-1), batch['y'].reshape(BS,-1))

        else:
            loss = initial_loss



        return loss