dsipts-1.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dsipts has been flagged as possibly problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/base.py
ADDED
@@ -0,0 +1,439 @@

from torch import optim
import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR
from abc import abstractmethod
from .utils import SinkhornDistance, SoftDTWBatch, PathDTWBatch, pairwise_distances
from ..data_structure.utils import beauty_string
from .samformer.utils import SAM
from .utils import get_scope
import numpy as np
from aim import Image
import matplotlib.pyplot as plt
from typing import List, Union
from .utils import QuantileLossMO
import torch.nn as nn


def standardize_momentum(x, order):
    # order-th central moment along the time axis, per sample and channel
    # (e.g. order=2 gives the biased variance)
    mean = torch.mean(x, 1).unsqueeze(1).repeat(1, x.shape[1], 1)
    num = torch.pow(x - mean, order).mean(axis=1)
    #den = torch.sqrt(torch.pow(x-mean,2).mean(axis=1)+1e-8)
    #den = torch.pow(den,order)

    return num  #/den


def dilate_loss(outputs, targets, alpha, gamma, device):
    # outputs, targets: shape (batch_size, N_output, 1)
    batch_size, N_output = outputs.shape[0:2]
    loss_shape = 0
    softdtw_batch = SoftDTWBatch.apply
    D = torch.zeros((batch_size, N_output, N_output)).to(device)
    for k in range(batch_size):
        Dk = pairwise_distances(targets[k, :, :].view(-1, 1), outputs[k, :, :].view(-1, 1))
        D[k:k + 1, :, :] = Dk
    loss_shape = softdtw_batch(D, gamma)

    path_dtw = PathDTWBatch.apply
    path = path_dtw(D, gamma)
    Omega = pairwise_distances(torch.arange(1.0, N_output + 1).view(N_output, 1)).to(device)  # 1..N_output as a column
    loss_temporal = torch.sum(path * Omega) / (N_output * N_output)
    loss = alpha * loss_shape + (1 - alpha) * loss_temporal
    return loss  #, loss_shape, loss_temporal
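
# Shape sketch for dilate_loss (illustrative values, not part of the API):
#   out = torch.rand(8, 16, 1)          # (batch_size, N_output, 1)
#   tgt = torch.rand(8, 16, 1)
#   l = dilate_loss(out, tgt, alpha=0.5, gamma=0.01, device=out.device)
# alpha balances the soft-DTW shape term against the temporal term; gamma is
# the soft-DTW smoothing (the 'dilated' branch of compute_loss uses gamma=0.01).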


class Base(pl.LightningModule):

    ############### SET THE PROPERTIES OF THE ARCHITECTURE ##############
    handle_multivariate = False
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
    #####################################################################

    @abstractmethod
    def __init__(self,
                 verbose: bool,
                 past_steps: int,
                 future_steps: int,
                 past_channels: int,
                 future_channels: int,
                 out_channels: int,
                 embs_past: List[int],
                 embs_fut: List[int],
                 n_classes: int = 0,
                 persistence_weight: float = 0.0,
                 loss_type: str = 'l1',
                 quantiles: List[int] = [],
                 reduction_mode: str = 'mean',
                 use_classical_positional_encoder: bool = False,
                 emb_dim: int = 16,
                 optim: Union[str, None] = None,
                 optim_config: dict = None,
                 scheduler_config: dict = None):
        """
        This is the base model: each model implemented must override the init method and the forward method.
        The inference step is optional; by default it uses the forward method, but for recurrent
        networks you should implement your own method.

        Args:
            verbose (bool): Flag to enable verbose logging.
            past_steps (int): Number of past time steps to consider.
            future_steps (int): Number of future time steps to predict.
            past_channels (int): Number of channels in the past input data.
            future_channels (int): Number of channels in the future input data.
            out_channels (int): Number of output channels.
            embs_past (List[int]): List of embedding dimensions for past data.
            embs_fut (List[int]): List of embedding dimensions for future data.
            n_classes (int, optional): Number of classes for classification. Defaults to 0.
            persistence_weight (float, optional): Weight for persistence in loss calculation. Defaults to 0.0.
            loss_type (str, optional): Type of loss function to use ('l1' or 'mse'). Defaults to 'l1'.
            quantiles (List[int], optional): List of quantiles for quantile loss. Defaults to an empty list.
            reduction_mode (str, optional): Reduction mode for the categorical embedding layer ('mean', 'sum', 'none'). Defaults to 'mean'.
            use_classical_positional_encoder (bool, optional): Flag to use classical positional encoding instead of an embedding layer for the positions. Defaults to False.
            emb_dim (int, optional): Dimension of categorical embeddings. Defaults to 16.
            optim (Union[str, None], optional): Optimizer type. Defaults to None.
            optim_config (dict, optional): Configuration for the optimizer. Defaults to None.
            scheduler_config (dict, optional): Configuration for the learning rate scheduler. Defaults to None.

        Raises:
            AssertionError: If the number of quantiles is not equal to 3 when quantiles are provided.
            AssertionError: If the number of output channels is not 1 for classification tasks.
        """
        beauty_string('V2', 'block', True)
        super(Base, self).__init__()
        self.save_hyperparameters(logger=False)
        self.count_epoch = 0
        self.initialize = False
        self.train_loss_epoch = -100.0
        self.verbose = verbose
        self.name = self.__class__.__name__
        self.train_epoch_metrics = []
        self.validation_epoch_metrics = []

        self.use_quantiles = True if len(quantiles) > 0 else False
        self.quantiles = quantiles
        self.optim = optim
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config
        self.loss_type = loss_type
        self.persistence_weight = persistence_weight
        self.use_classical_positional_encoder = use_classical_positional_encoder
        self.reduction_mode = reduction_mode
        self.past_steps = past_steps
        self.future_steps = future_steps
        self.embs_past = embs_past
        self.embs_fut = embs_fut
        self.past_channels = past_channels
        self.future_channels = future_channels
        self.emb_dim = emb_dim
        self.out_channels = out_channels
        self.n_classes = n_classes
        if n_classes == 0:
            self.is_classification = False
            if len(self.quantiles) > 0:
                assert len(self.quantiles) == 3, beauty_string('ONLY 3 quantiles permitted', 'info', True)
                self.use_quantiles = True
                self.mul = len(self.quantiles)
                self.loss = QuantileLossMO(quantiles)
            else:
                self.use_quantiles = False
                self.mul = 1
                if self.loss_type == 'mse':
                    self.loss = nn.MSELoss()
                else:
                    self.loss = nn.L1Loss()
        else:
            self.is_classification = True
            self.use_quantiles = False
            self.mul = n_classes
            self.loss = torch.nn.CrossEntropyLoss()
            assert self.out_channels == 1, "Classification requires exactly one output channel"

        self.future_steps = future_steps

        beauty_string(self.description, 'info', True)
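
    # Loss selection sketch (illustrative values):
    #   quantiles=[0.1, 0.5, 0.9]  ->  self.mul = 3, self.loss = QuantileLossMO(quantiles)
    #   quantiles=[]               ->  self.mul = 1, self.loss = L1 or MSE (per loss_type)
    #   n_classes > 0              ->  self.mul = n_classes, self.loss = CrossEntropyLoss()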

    @abstractmethod
    def forward(self, batch: dict) -> torch.tensor:
        """Forward method used during the training loop.

        Args:
            batch (dict): the batch structure. The keys are:
                y: the target variable(s). This is always present.
                x_num_past: the numerical past variables. This is always present.
                x_num_future: the numerical future variables.
                x_cat_past: the categorical past variables.
                x_cat_future: the categorical future variables.
                idx_target: index of the target features in the past array.

        Returns:
            torch.tensor: output of the model
        """
        return None
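
    # Example batch (shapes are illustrative; B = batch size, and the
    # categorical counts assume len(embs_past) past and len(embs_fut) future
    # categorical features):
    #   batch = {
    #       'y':            torch.rand(B, future_steps, out_channels),
    #       'x_num_past':   torch.rand(B, past_steps, past_channels),
    #       'x_num_future': torch.rand(B, future_steps, future_channels),
    #       'x_cat_past':   torch.randint(0, 10, (B, past_steps, len(embs_past))),
    #       'x_cat_future': torch.randint(0, 10, (B, future_steps, len(embs_fut))),
    #       'idx_target':   torch.tensor([[0]] * B),
    #   }
    # compute_loss expects forward to return shape (B, future_steps, out_channels, self.mul).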

    def inference(self, batch: dict) -> torch.tensor:
        """Usually it is ok to return the output of the forward method, but sometimes it is not (e.g. for RNNs).

        Args:
            batch (dict): batch

        Returns:
            torch.tensor: result
        """
        return self(batch)

    def configure_optimizers(self):
        """
        Each model has optim_config and scheduler_config.

        :meta private:
        """

        self.has_sam_optim = False
        if self.optim_config is None:
            self.optim_config = {'lr': 5e-05}

        if self.optim is None:
            optimizer = optim.Adam(self.parameters(), **self.optim_config)
            self.initialize = True
        else:
            if self.initialize is False:
                if self.optim == 'SAM':
                    self.has_sam_optim = True
                    self.automatic_optimization = False
                    self.my_step = 0
                else:
                    self.optim = eval(self.optim)
                    self.has_sam_optim = False
                    self.automatic_optimization = True

            beauty_string(self.optim, '', self.verbose)
            if self.has_sam_optim:
                optimizer = SAM(self.parameters(), base_optimizer=torch.optim.Adam, **self.optim_config)
            else:
                optimizer = self.optim(self.parameters(), **self.optim_config)
            beauty_string(optimizer, '', self.verbose)
            self.initialize = True
        self.lr = self.optim_config['lr']
        if self.scheduler_config is not None:
            scheduler = StepLR(optimizer, **self.scheduler_config)
            return [optimizer], [scheduler]
        else:
            return optimizer
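
    # Optimizer configuration sketch (illustrative values): optim may be None
    # (plain Adam), the string 'SAM' (manual optimization with the SAM wrapper
    # around Adam), or an optimizer path eval'd from a string such as
    # 'torch.optim.AdamW'. scheduler_config holds StepLR keyword arguments:
    #   optim_config     = {'lr': 5e-4, 'weight_decay': 1e-5}
    #   scheduler_config = {'step_size': 100, 'gamma': 0.1}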

    def training_step(self, batch, batch_idx):
        """
        PyTorch Lightning hook.

        :meta private:
        """

        #loss = self.compute_loss(batch,y_hat)
        #import pdb
        #pdb.set_trace()

        if self.has_sam_optim:
            # SAM needs a closure that re-evaluates the loss for its second step
            opt = self.optimizers()

            def closure():
                opt.zero_grad()
                y_hat = self(batch)
                loss = self.compute_loss(batch, y_hat)
                self.manual_backward(loss)
                return loss

            opt.step(closure)
            y_hat = self(batch)
            loss = self.compute_loss(batch, y_hat)

            #opt.first_step(zero_grad=True)
            #y_hat = self(batch)
            #loss = self.compute_loss(batch, y_hat)
            #self.my_step+=1
            #self.manual_backward(loss,retain_graph=True)
            #opt.second_step(zero_grad=True)
            #self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
            #self.log("global_step", self.my_step, on_step=True)
            #self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.increment("optimizer")
        else:
            y_hat = self(batch)
            loss = self.compute_loss(batch, y_hat)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        PyTorch Lightning hook.

        :meta private:
        """
        y_hat = self(batch)
        if batch_idx == 0:
            if self.use_quantiles:
                idx = 1
            else:
                idx = 0
            # track the predictions! We can do better than this, but maybe it is better to first update pytorch-lightning

            if self.count_epoch % int(max(self.trainer.max_epochs / 100, 1)) == 1:
                for i in range(batch['y'].shape[2]):
                    real = batch['y'][0, :, i].cpu().detach().numpy()
                    pred = y_hat[0, :, i, idx].cpu().detach().numpy()
                    fig, ax = plt.subplots(figsize=(7, 5))
                    ax.plot(real, 'o-', label='real')
                    ax.plot(pred, 'o-', label='pred')
                    ax.legend()
                    ax.set_title(f'Channel {i} first element first batch validation {int(100 * self.count_epoch / self.trainer.max_epochs)}%')
                    self.logger.experiment.track(Image(fig), name='cm_training_end')
                    #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)

        return self.compute_loss(batch, y_hat)

    def validation_epoch_end(self, outs):
        """
        PyTorch Lightning hook.

        :meta private:
        """

        loss = torch.stack(outs).mean()
        self.log("val_loss", loss.item(), sync_dist=True)
        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {loss.item():.4f}', 'info', self.verbose)

    def training_epoch_end(self, outs):
        """
        PyTorch Lightning hook.

        :meta private:
        """

        loss = sum(out['loss'] for out in outs) / len(outs)
        self.log("train_loss", loss.item(), sync_dist=True)
        self.count_epoch += 1

        self.train_loss_epoch = loss.item()

    def compute_loss(self, batch, y_hat):
        """
        Custom loss calculation.

        :meta private:
        """

        if self.use_quantiles is False:
            initial_loss = self.loss(y_hat[:, :, :, 0], batch['y'])
        else:
            initial_loss = self.loss(y_hat, batch['y'])
        x = batch['x_num_past'].to(self.device)
        idx_target = batch['idx_target'][0]
        # persistence baseline: repeat the last observed target over the horizon
        x_start = x[:, -1, idx_target].unsqueeze(1)
        y_persistence = x_start.repeat(1, self.future_steps, 1)

        ## generally you want to work without quantile loss
        if self.use_quantiles is False:
            x = y_hat[:, :, :, 0]
        else:
            x = y_hat[:, :, :, 1]

        if self.loss_type == 'linear_penalization':
            persistence_error = (2.0 - 10.0 * torch.clamp(torch.abs((y_persistence - x) / (0.001 + torch.abs(y_persistence))), min=0.0, max=max(0.05, 0.1 * (1 + np.log10(self.persistence_weight)))))
            loss = torch.mean(torch.abs(x - batch['y']) * persistence_error)

        elif self.loss_type == 'mda':
            #import pdb
            #pdb.set_trace()
            mda = (1 - torch.mean(torch.sign(torch.diff(x, axis=1)) * torch.sign(torch.diff(batch['y'], axis=1))))
            loss = torch.mean(torch.abs(x - batch['y']).mean(axis=1).flatten()) + self.persistence_weight * mda / 10

        elif self.loss_type == 'exponential_penalization':
            weights = (1 + self.persistence_weight * torch.exp(-torch.abs(y_persistence - x)))
            loss = torch.mean(torch.abs(x - batch['y']) * weights)

        elif self.loss_type == 'sinkhorn':
            sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='mean')
            loss = sinkhorn.compute(x, batch['y'])

        elif self.loss_type == 'additive_iv':
            std = torch.sqrt(torch.var(batch['y'], dim=(1)) + 1e-8)  ##--> BSxChannel
            x_std = torch.sqrt(torch.var(x, dim=(1)) + 1e-8)
            loss = torch.mean(torch.abs(x - batch['y']).mean(axis=1).flatten() + self.persistence_weight * torch.abs(x_std - std).flatten())

        elif self.loss_type == 'multiplicative_iv':
            std = torch.sqrt(torch.var(batch['y'], dim=(1)) + 1e-8)  ##--> BSxChannel
            x_std = torch.sqrt(torch.var(x, dim=(1)) + 1e-8)
            if self.persistence_weight > 0:
                loss = torch.mean(torch.abs(x - batch['y']).mean(axis=1) * torch.abs(x_std - std))
            else:
                loss = torch.mean(torch.abs(x - batch['y']).mean(axis=1))

        elif self.loss_type == 'global_iv':
            std_real = torch.sqrt(torch.var(batch['y'], dim=(0, 1)))
            std_predict = torch.sqrt(torch.var(x, dim=(0, 1)))
            loss = initial_loss + self.persistence_weight * torch.abs(std_real - std_predict).mean()

        elif self.loss_type == 'smape':
            loss = torch.mean(2 * torch.abs(x - batch['y']) / (torch.abs(x) + torch.abs(batch['y'])))

        elif self.loss_type == 'triplet':
            loss_fn = torch.nn.TripletMarginLoss(margin=0.01, p=1.0, swap=False)
            loss = initial_loss + self.persistence_weight * loss_fn(x, batch['y'], y_persistence)

        elif self.loss_type == 'high_order':
            loss = initial_loss
            for i in range(2, 5):
                mom_real = standardize_momentum(batch['y'], i)
                mom_pred = standardize_momentum(x, i)

                mom_loss = torch.abs(mom_real - mom_pred).mean()
                loss += self.persistence_weight * mom_loss

        elif self.loss_type == 'dilated':
            #BxLxCxMUL
            # NOTE: the branches below are dead code, alpha is unconditionally
            # overwritten by self.persistence_weight right after them
            if self.persistence_weight == 0.1:
                alpha = 0.25
            if self.persistence_weight == 1:
                alpha = 0.5
            else:
                alpha = 0.75
            alpha = self.persistence_weight
            gamma = 0.01
            loss = 0
            ## no multichannel here
            for i in range(y_hat.shape[2]):
                ## error here
                loss += dilate_loss(batch['y'][:, :, i:i + 1], x[:, :, i:i + 1], alpha, gamma, y_hat.device)

        elif self.loss_type == 'huber':
            loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight / 10)
            #loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
            if self.use_quantiles is False:
                x = y_hat[:, :, :, 0]
            else:
                x = y_hat[:, :, :, 1]
            BS = x.shape[0]
            loss = loss(x.reshape(BS, -1), batch['y'].reshape(BS, -1))

        else:
            loss = initial_loss

        return loss
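
For orientation, a minimal subclass sketch (illustrative only, not taken from the package; TinyLinear and its layer sizes are made up, but the constructor arguments and the (B, future_steps, out_channels, mul) output convention follow Base as defined above):

import torch
import torch.nn as nn
from dsipts.models.base import Base

class TinyLinear(Base):
    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # flatten the past window and project it to the flattened forecast
        self.proj = nn.Linear(self.past_steps * self.past_channels,
                              self.future_steps * self.out_channels * self.mul)

    def forward(self, batch):
        x = batch['x_num_past']          # (B, past_steps, past_channels)
        B = x.shape[0]
        out = self.proj(x.reshape(B, -1))
        # Base.compute_loss expects (B, future_steps, out_channels, mul)
        return out.reshape(B, self.future_steps, self.out_channels, self.mul)

model = TinyLinear(verbose=False, past_steps=64, future_steps=16,
                   past_channels=3, future_channels=0, out_channels=1,
                   embs_past=[], embs_fut=[], quantiles=[],
                   optim='torch.optim.AdamW', optim_config={'lr': 1e-3})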