dsipts 1.1.6__py3-none-any.whl → 1.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic. Click here for more details.
- dsipts/__init__.py +2 -2
- dsipts/data_structure/data_structure.py +2 -0
- dsipts/models/DilatedConv.py +3 -20
- dsipts/models/RNN.py +1 -1
- dsipts/models/Simple.py +113 -0
- dsipts/models/base.py +36 -19
- dsipts/models/base_v2.py +40 -22
- {dsipts-1.1.6.dist-info → dsipts-1.1.8.dist-info}/METADATA +11 -10
- {dsipts-1.1.6.dist-info → dsipts-1.1.8.dist-info}/RECORD +11 -10
- {dsipts-1.1.6.dist-info → dsipts-1.1.8.dist-info}/WHEEL +0 -0
- {dsipts-1.1.6.dist-info → dsipts-1.1.8.dist-info}/top_level.txt +0 -0
dsipts/__init__.py
CHANGED
|
@@ -25,7 +25,7 @@ from .models.TimeXER import TimeXER
|
|
|
25
25
|
from .models.TTM import TTM
|
|
26
26
|
from .models.Samformer import Samformer
|
|
27
27
|
from .models.Duet import Duet
|
|
28
|
-
|
|
28
|
+
from .models.Simple import Simple
|
|
29
29
|
try:
|
|
30
30
|
import lightning.pytorch as pl
|
|
31
31
|
from .models.base_v2 import Base
|
|
@@ -44,5 +44,5 @@ __all__ = [
|
|
|
44
44
|
"RNN", "LinearTS", "Persistent", "D3VAE", "DilatedConv", "TFT",
|
|
45
45
|
"Informer", "VVA", "VQVAEA", "CrossFormer", "Autoformer", "PatchTST",
|
|
46
46
|
"Diffusion", "DilatedConvED", "TIDE", "ITransformer", "TimeXER",
|
|
47
|
-
"TTM", "Samformer", "Duet", "Base"
|
|
47
|
+
"TTM", "Samformer", "Duet", "Base", "Simple"
|
|
48
48
|
]
|
|
@@ -800,6 +800,7 @@ class TimeSeries():
|
|
|
800
800
|
callbacks=[checkpoint_callback,mc],
|
|
801
801
|
auto_lr_find=auto_lr_find,
|
|
802
802
|
accelerator=accelerator,
|
|
803
|
+
log_every_n_steps=1,
|
|
803
804
|
devices=devices,
|
|
804
805
|
strategy=strategy,
|
|
805
806
|
enable_progress_bar=False,
|
|
@@ -813,6 +814,7 @@ class TimeSeries():
|
|
|
813
814
|
callbacks=[checkpoint_callback,mc],
|
|
814
815
|
strategy='auto',
|
|
815
816
|
devices=devices,
|
|
817
|
+
log_every_n_steps=5,
|
|
816
818
|
enable_progress_bar=False,
|
|
817
819
|
precision=precision,
|
|
818
820
|
gradient_clip_val=gradient_clip_val,
|
dsipts/models/DilatedConv.py
CHANGED
|
@@ -231,28 +231,11 @@ class DilatedConv(Base):
|
|
|
231
231
|
activation(),
|
|
232
232
|
nn.Linear(hidden_RNN//4,1)))
|
|
233
233
|
|
|
234
|
+
self.return_additional_loss = True
|
|
234
235
|
|
|
235
236
|
|
|
236
237
|
|
|
237
238
|
|
|
238
|
-
def training_step(self, batch, batch_idx):
|
|
239
|
-
"""
|
|
240
|
-
pythotrch lightening stuff
|
|
241
|
-
|
|
242
|
-
:meta private:
|
|
243
|
-
"""
|
|
244
|
-
y_hat,score = self(batch)
|
|
245
|
-
return self.compute_loss(batch,y_hat)#+torch.abs(score-self.glu_percentage)*loss/5.0 ##TODO investigating
|
|
246
|
-
|
|
247
|
-
def validation_step(self, batch, batch_idx):
|
|
248
|
-
"""
|
|
249
|
-
pythotrch lightening stuff
|
|
250
|
-
|
|
251
|
-
:meta private:
|
|
252
|
-
"""
|
|
253
|
-
y_hat,score = self(batch)
|
|
254
|
-
return self.compute_loss(batch,y_hat)#+torch.abs(score-self.glu_percentage)*loss/5.0 ##TODO investigating
|
|
255
|
-
|
|
256
239
|
def forward(self, batch):
|
|
257
240
|
"""It is mandatory to implement this method
|
|
258
241
|
|
|
@@ -332,11 +315,11 @@ class DilatedConv(Base):
|
|
|
332
315
|
res = res.reshape(B,self.future_steps,-1,self.mul)
|
|
333
316
|
if self.remove_last:
|
|
334
317
|
res+=x_start.unsqueeze(1)
|
|
335
|
-
|
|
318
|
+
|
|
336
319
|
|
|
337
320
|
return res, score
|
|
338
321
|
|
|
339
322
|
def inference(self, batch:dict)->torch.tensor:
|
|
340
|
-
|
|
323
|
+
|
|
341
324
|
res, score = self(batch)
|
|
342
325
|
return res
|
dsipts/models/RNN.py
CHANGED
|
@@ -16,7 +16,7 @@ from ..data_structure.utils import beauty_string
|
|
|
16
16
|
from .utils import get_scope
|
|
17
17
|
from .xlstm.xLSTM import xLSTM
|
|
18
18
|
from .utils import Embedding_cat_variables
|
|
19
|
-
|
|
19
|
+
torch.autograd.set_detect_anomaly(True)
|
|
20
20
|
|
|
21
21
|
class MyBN(nn.Module):
|
|
22
22
|
def __init__(self,channels):
|
dsipts/models/Simple.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
|
|
2
|
+
## Copyright 2022 DLinear Authors (https://github.com/cure-lab/LTSF-Linear/tree/main?tab=Apache-2.0-1-ov-file#readme)
|
|
3
|
+
## Code modified for align the notation and the batch generation
|
|
4
|
+
## extended to all present in informer, autoformer folder
|
|
5
|
+
|
|
6
|
+
from torch import nn
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import lightning.pytorch as pl
|
|
11
|
+
from .base_v2 import Base
|
|
12
|
+
OLD_PL = False
|
|
13
|
+
except:
|
|
14
|
+
import pytorch_lightning as pl
|
|
15
|
+
OLD_PL = True
|
|
16
|
+
from .base import Base
|
|
17
|
+
from .utils import QuantileLossMO, get_activation
|
|
18
|
+
from typing import List, Union
|
|
19
|
+
from ..data_structure.utils import beauty_string
|
|
20
|
+
from .utils import get_scope
|
|
21
|
+
from .utils import Embedding_cat_variables
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Simple(Base):
|
|
27
|
+
handle_multivariate = True
|
|
28
|
+
handle_future_covariates = True
|
|
29
|
+
handle_categorical_variables = True
|
|
30
|
+
handle_quantile_loss = True
|
|
31
|
+
description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
|
|
32
|
+
description+='\n THE SIMPLE IMPLEMENTATION DOES NOT USE CATEGORICAL NOR FUTURE VARIABLES'
|
|
33
|
+
|
|
34
|
+
def __init__(self,
|
|
35
|
+
|
|
36
|
+
hidden_size:int,
|
|
37
|
+
dropout_rate:float=0.1,
|
|
38
|
+
activation:str='torch.nn.ReLU',
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
**kwargs)->None:
|
|
42
|
+
|
|
43
|
+
super().__init__(**kwargs)
|
|
44
|
+
|
|
45
|
+
if activation == 'torch.nn.SELU':
|
|
46
|
+
beauty_string('SELU do not require BN','info',self.verbose)
|
|
47
|
+
use_bn = False
|
|
48
|
+
|
|
49
|
+
if isinstance(activation, str):
|
|
50
|
+
activation = get_activation(activation)
|
|
51
|
+
else:
|
|
52
|
+
beauty_string('There is a bug in pytorch lightening, the constructior is called twice','info',self.verbose)
|
|
53
|
+
|
|
54
|
+
self.save_hyperparameters(logger=False)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
self.emb_past = Embedding_cat_variables(self.past_steps,self.emb_dim,self.embs_past, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
|
|
59
|
+
self.emb_fut = Embedding_cat_variables(self.future_steps,self.emb_dim,self.embs_fut, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
|
|
60
|
+
emb_past_out_channel = self.emb_past.output_channels
|
|
61
|
+
emb_fut_out_channel = self.emb_fut.output_channels
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
self.linear = (nn.Sequential(nn.Linear(emb_past_out_channel*self.past_steps+emb_fut_out_channel*self.future_steps+self.past_steps*self.past_channels+self.future_channels*self.future_steps,hidden_size),
|
|
68
|
+
activation(),nn.Dropout(dropout_rate),
|
|
69
|
+
nn.Linear(hidden_size,self.out_channels*self.future_steps*self.mul)))
|
|
70
|
+
|
|
71
|
+
def forward(self, batch):
|
|
72
|
+
|
|
73
|
+
x = batch['x_num_past'].to(self.device)
|
|
74
|
+
|
|
75
|
+
BS = x.shape[0]
|
|
76
|
+
if 'x_cat_future' in batch.keys():
|
|
77
|
+
emb_fut = self.emb_fut(BS,batch['x_cat_future'].to(self.device))
|
|
78
|
+
else:
|
|
79
|
+
emb_fut = self.emb_fut(BS,None)
|
|
80
|
+
if 'x_cat_past' in batch.keys():
|
|
81
|
+
emb_past = self.emb_past(BS,batch['x_cat_past'].to(self.device))
|
|
82
|
+
else:
|
|
83
|
+
emb_past = self.emb_past(BS,None)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
if 'x_num_future' in batch.keys():
|
|
89
|
+
x_future = batch['x_num_future'].to(self.device)
|
|
90
|
+
else:
|
|
91
|
+
x_future = None
|
|
92
|
+
|
|
93
|
+
tmp = [x,emb_past]
|
|
94
|
+
tot_past = torch.cat(tmp,2).flatten(1)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
tmp = [emb_fut]
|
|
99
|
+
|
|
100
|
+
if x_future is not None:
|
|
101
|
+
tmp.append(x_future)
|
|
102
|
+
|
|
103
|
+
tot_future = torch.cat(tmp,2).flatten(1)
|
|
104
|
+
tot = torch.cat([tot_past,tot_future],1)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
res = self.linear(tot)
|
|
108
|
+
res = res.reshape(BS,self.future_steps,-1,self.mul)
|
|
109
|
+
|
|
110
|
+
##
|
|
111
|
+
|
|
112
|
+
return res
|
|
113
|
+
|
dsipts/models/base.py
CHANGED
|
@@ -154,7 +154,7 @@ class Base(pl.LightningModule):
|
|
|
154
154
|
assert self.out_channels==1, "Classification require only one channel"
|
|
155
155
|
|
|
156
156
|
self.future_steps = future_steps
|
|
157
|
-
|
|
157
|
+
self.return_additional_loss = False
|
|
158
158
|
beauty_string(self.description,'info',True)
|
|
159
159
|
@abstractmethod
|
|
160
160
|
def forward(self, batch:dict)-> torch.tensor:
|
|
@@ -247,14 +247,22 @@ class Base(pl.LightningModule):
|
|
|
247
247
|
opt = self.optimizers()
|
|
248
248
|
def closure():
|
|
249
249
|
opt.zero_grad()
|
|
250
|
-
|
|
251
|
-
|
|
250
|
+
if self.return_additional_loss:
|
|
251
|
+
y_hat,score = self(batch)
|
|
252
|
+
loss = self.compute_loss(batch,y_hat) + score
|
|
253
|
+
else:
|
|
254
|
+
y_hat = self(batch)
|
|
255
|
+
loss = self.compute_loss(batch,y_hat)
|
|
252
256
|
self.manual_backward(loss)
|
|
253
257
|
return loss
|
|
254
258
|
|
|
255
259
|
opt.step(closure)
|
|
256
|
-
|
|
257
|
-
|
|
260
|
+
if self.return_additional_loss:
|
|
261
|
+
y_hat,score = self(batch)
|
|
262
|
+
loss = self.compute_loss(batch,y_hat)+score
|
|
263
|
+
else:
|
|
264
|
+
y_hat = self(batch)
|
|
265
|
+
loss = self.compute_loss(batch,y_hat)
|
|
258
266
|
|
|
259
267
|
#opt.first_step(zero_grad=True)
|
|
260
268
|
|
|
@@ -269,8 +277,14 @@ class Base(pl.LightningModule):
|
|
|
269
277
|
|
|
270
278
|
#self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.increment("optimizer")
|
|
271
279
|
else:
|
|
272
|
-
|
|
273
|
-
|
|
280
|
+
if self.return_additional_loss:
|
|
281
|
+
y_hat,score = self(batch)
|
|
282
|
+
loss = self.compute_loss(batch,y_hat)+score
|
|
283
|
+
else:
|
|
284
|
+
y_hat = self(batch)
|
|
285
|
+
loss = self.compute_loss(batch,y_hat)
|
|
286
|
+
|
|
287
|
+
self.train_epoch_metrics.append(loss.item())
|
|
274
288
|
return loss
|
|
275
289
|
|
|
276
290
|
|
|
@@ -280,7 +294,11 @@ class Base(pl.LightningModule):
|
|
|
280
294
|
|
|
281
295
|
:meta private:
|
|
282
296
|
"""
|
|
283
|
-
|
|
297
|
+
if self.return_additional_loss:
|
|
298
|
+
y_hat,score = self(batch)
|
|
299
|
+
else:
|
|
300
|
+
y_hat = self(batch)
|
|
301
|
+
score = 0
|
|
284
302
|
if batch_idx==0:
|
|
285
303
|
if self.use_quantiles:
|
|
286
304
|
idx = 1
|
|
@@ -301,7 +319,7 @@ class Base(pl.LightningModule):
|
|
|
301
319
|
self.logger.experiment.track(Image(fig), name='cm_training_end')
|
|
302
320
|
#self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
|
|
303
321
|
|
|
304
|
-
return self.compute_loss(batch,y_hat)
|
|
322
|
+
return self.compute_loss(batch,y_hat)+score
|
|
305
323
|
|
|
306
324
|
|
|
307
325
|
def validation_epoch_end(self, outs):
|
|
@@ -310,8 +328,12 @@ class Base(pl.LightningModule):
|
|
|
310
328
|
|
|
311
329
|
:meta private:
|
|
312
330
|
"""
|
|
313
|
-
|
|
314
|
-
|
|
331
|
+
if len(outs)==0:
|
|
332
|
+
loss = 10000
|
|
333
|
+
beauty_string(f'THIS IS A BUG, It should be polulated','info',self.verbose)
|
|
334
|
+
else:
|
|
335
|
+
loss = torch.stack(outs).mean()
|
|
336
|
+
|
|
315
337
|
self.log("val_loss", loss.item(),sync_dist=True)
|
|
316
338
|
beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {loss.item():.4f}','info',self.verbose)
|
|
317
339
|
|
|
@@ -406,12 +428,7 @@ class Base(pl.LightningModule):
|
|
|
406
428
|
|
|
407
429
|
elif self.loss_type=='dilated':
|
|
408
430
|
#BxLxCxMUL
|
|
409
|
-
|
|
410
|
-
alpha = 0.25
|
|
411
|
-
if self.persistence_weight==1:
|
|
412
|
-
alpha = 0.5
|
|
413
|
-
else:
|
|
414
|
-
alpha =0.75
|
|
431
|
+
|
|
415
432
|
alpha = self.persistence_weight
|
|
416
433
|
gamma = 0.01
|
|
417
434
|
loss = 0
|
|
@@ -422,8 +439,8 @@ class Base(pl.LightningModule):
|
|
|
422
439
|
loss+= dilate_loss( batch['y'][:,:,i:i+1],x[:,:,i:i+1], alpha, gamma, y_hat.device)
|
|
423
440
|
|
|
424
441
|
elif self.loss_type=='huber':
|
|
425
|
-
loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight
|
|
426
|
-
|
|
442
|
+
loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
|
|
443
|
+
|
|
427
444
|
if self.use_quantiles is False:
|
|
428
445
|
x = y_hat[:,:,:,0]
|
|
429
446
|
else:
|
dsipts/models/base_v2.py
CHANGED
|
@@ -157,7 +157,7 @@ class Base(pl.LightningModule):
|
|
|
157
157
|
|
|
158
158
|
|
|
159
159
|
self.future_steps = future_steps
|
|
160
|
-
|
|
160
|
+
self.return_additional_loss = False
|
|
161
161
|
beauty_string(self.description,'info',True)
|
|
162
162
|
@abstractmethod
|
|
163
163
|
def forward(self, batch:dict)-> torch.tensor:
|
|
@@ -250,14 +250,22 @@ class Base(pl.LightningModule):
|
|
|
250
250
|
opt = self.optimizers()
|
|
251
251
|
def closure():
|
|
252
252
|
opt.zero_grad()
|
|
253
|
-
|
|
254
|
-
|
|
253
|
+
if self.return_additional_loss:
|
|
254
|
+
y_hat,score = self(batch)
|
|
255
|
+
loss = self.compute_loss(batch,y_hat) + score
|
|
256
|
+
else:
|
|
257
|
+
y_hat = self(batch)
|
|
258
|
+
loss = self.compute_loss(batch,y_hat)
|
|
255
259
|
self.manual_backward(loss)
|
|
256
260
|
return loss
|
|
257
261
|
|
|
258
262
|
opt.step(closure)
|
|
259
|
-
|
|
260
|
-
|
|
263
|
+
if self.return_additional_loss:
|
|
264
|
+
y_hat,score = self(batch)
|
|
265
|
+
loss = self.compute_loss(batch,y_hat)+score
|
|
266
|
+
else:
|
|
267
|
+
y_hat = self(batch)
|
|
268
|
+
loss = self.compute_loss(batch,y_hat)
|
|
261
269
|
|
|
262
270
|
#opt.first_step(zero_grad=True)
|
|
263
271
|
|
|
@@ -272,8 +280,12 @@ class Base(pl.LightningModule):
|
|
|
272
280
|
|
|
273
281
|
#self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.increment("optimizer")
|
|
274
282
|
else:
|
|
275
|
-
|
|
276
|
-
|
|
283
|
+
if self.return_additional_loss:
|
|
284
|
+
y_hat,score = self(batch)
|
|
285
|
+
loss = self.compute_loss(batch,y_hat)+score
|
|
286
|
+
else:
|
|
287
|
+
y_hat = self(batch)
|
|
288
|
+
loss = self.compute_loss(batch,y_hat)
|
|
277
289
|
|
|
278
290
|
self.train_epoch_metrics.append(loss.item())
|
|
279
291
|
return loss
|
|
@@ -285,7 +297,12 @@ class Base(pl.LightningModule):
|
|
|
285
297
|
|
|
286
298
|
:meta private:
|
|
287
299
|
"""
|
|
288
|
-
|
|
300
|
+
|
|
301
|
+
if self.return_additional_loss:
|
|
302
|
+
y_hat,score = self(batch)
|
|
303
|
+
else:
|
|
304
|
+
y_hat = self(batch)
|
|
305
|
+
score = 0
|
|
289
306
|
if batch_idx==0:
|
|
290
307
|
if self.use_quantiles:
|
|
291
308
|
idx = 1
|
|
@@ -305,7 +322,7 @@ class Base(pl.LightningModule):
|
|
|
305
322
|
ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
|
|
306
323
|
self.logger.experiment.track(Image(fig), name='cm_training_end')
|
|
307
324
|
#self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
|
|
308
|
-
self.validation_epoch_metrics.append(self.compute_loss(batch,y_hat))
|
|
325
|
+
self.validation_epoch_metrics.append(self.compute_loss(batch,y_hat)+score)
|
|
309
326
|
return
|
|
310
327
|
|
|
311
328
|
|
|
@@ -315,7 +332,12 @@ class Base(pl.LightningModule):
|
|
|
315
332
|
|
|
316
333
|
:meta private:
|
|
317
334
|
"""
|
|
318
|
-
|
|
335
|
+
|
|
336
|
+
if len(self.validation_epoch_metrics)==0:
|
|
337
|
+
avg = 10000
|
|
338
|
+
beauty_string(f'THIS IS A BUG, It should be polulated','info',self.verbose)
|
|
339
|
+
else:
|
|
340
|
+
avg = torch.stack(self.validation_epoch_metrics).mean()
|
|
319
341
|
self.validation_epoch_metrics = []
|
|
320
342
|
self.log("val_loss", avg,sync_dist=True)
|
|
321
343
|
beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
|
|
@@ -327,7 +349,11 @@ class Base(pl.LightningModule):
|
|
|
327
349
|
|
|
328
350
|
:meta private:
|
|
329
351
|
"""
|
|
330
|
-
|
|
352
|
+
if len(self.train_epoch_metrics)==0:
|
|
353
|
+
avg = 0
|
|
354
|
+
beauty_string(f'THIS IS A BUG, It should be polulated','info',self.verbose)
|
|
355
|
+
else:
|
|
356
|
+
avg = np.stack(self.train_epoch_metrics).mean()
|
|
331
357
|
self.log("train_loss", avg,sync_dist=True)
|
|
332
358
|
self.count_epoch+=1
|
|
333
359
|
self.train_epoch_metrics = []
|
|
@@ -411,24 +437,16 @@ class Base(pl.LightningModule):
|
|
|
411
437
|
|
|
412
438
|
elif self.loss_type=='dilated':
|
|
413
439
|
#BxLxCxMUL
|
|
414
|
-
|
|
415
|
-
alpha = 0.25
|
|
416
|
-
if self.persistence_weight==1:
|
|
417
|
-
alpha = 0.5
|
|
418
|
-
else:
|
|
419
|
-
alpha =0.75
|
|
440
|
+
|
|
420
441
|
alpha = self.persistence_weight
|
|
421
442
|
gamma = 0.01
|
|
422
443
|
loss = 0
|
|
423
444
|
##no multichannel here
|
|
424
|
-
for i in range(y_hat.shape[2]):
|
|
425
|
-
##error here
|
|
426
|
-
|
|
445
|
+
for i in range(y_hat.shape[2]):
|
|
427
446
|
loss+= dilate_loss( batch['y'][:,:,i:i+1],x[:,:,i:i+1], alpha, gamma, y_hat.device)
|
|
428
447
|
|
|
429
448
|
elif self.loss_type=='huber':
|
|
430
|
-
loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight
|
|
431
|
-
#loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
|
|
449
|
+
loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
|
|
432
450
|
if self.use_quantiles is False:
|
|
433
451
|
x = y_hat[:,:,:,0]
|
|
434
452
|
else:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dsipts
|
|
3
|
-
Version: 1.1.
|
|
3
|
+
Version: 1.1.8
|
|
4
4
|
Summary: Unified library for timeseries modelling
|
|
5
5
|
Author-email: Andrea Gobbi <agobbi@fbk.eu>
|
|
6
6
|
Project-URL: Homepage, https://github.com/DSIP-FBK/DSIPTS
|
|
@@ -128,10 +128,14 @@ or attention based models:
|
|
|
128
128
|
## Install
|
|
129
129
|
Clone the repo (gitlab or github)
|
|
130
130
|
The library is structured to work with [uv](https://github.com/astral-sh/uv). After installing `uv` just run
|
|
131
|
-
```
|
|
131
|
+
```bash
|
|
132
132
|
uv pip install .
|
|
133
133
|
```
|
|
134
|
-
|
|
134
|
+
You can install also the package from pip (be sure that the python version is less than 3.12, still sperimental):
|
|
135
|
+
```bash
|
|
136
|
+
uv venv --python 3.11
|
|
137
|
+
uv pip install dsipts
|
|
138
|
+
```
|
|
135
139
|
|
|
136
140
|
|
|
137
141
|
## For developers
|
|
@@ -352,6 +356,9 @@ loss, quantile loss, MDA and a couple of experimental losses for minimizing the
|
|
|
352
356
|
# Bash experiment
|
|
353
357
|
Most of the time you want to train the models in a cluster with a GPU and command line training procedure can help speedup the process. DSIPTS leverages on OmegaConf-Hydra to to this and in the folder `bash_examples` you can find an examples. Please read the documentation [here](/bash_examples/README.md)
|
|
354
358
|
|
|
359
|
+
## Losses
|
|
360
|
+
|
|
361
|
+
- `dilated`: `persistence_weight` between 0 and 1
|
|
355
362
|
|
|
356
363
|
|
|
357
364
|
# Modifiers
|
|
@@ -362,13 +369,7 @@ The VVA model is composed by two steps: the first is a clusterting procedure tha
|
|
|
362
369
|
- **inverse_transform**: the output of the model are reverted to the original shape. In the VVA model the centroids are used for reconstruct the predicted timeseries.
|
|
363
370
|
|
|
364
371
|
|
|
365
|
-
|
|
366
|
-
You can find the documentation [here](https://dsip.pages.fbk.eu/dsip_dlresearch/timeseries/):
|
|
367
|
-
or in the folder `docs/_build/html/index.html`
|
|
368
|
-
If yon need to generate the documentation after some modification just run:
|
|
369
|
-
```
|
|
370
|
-
./make_doc.sh
|
|
371
|
-
```
|
|
372
|
+
|
|
372
373
|
|
|
373
374
|
For user only: be sure that the the CI file has pages enabled, see [public pages](https://roneo.org/en/gitlab-public-pages-private-repo/)
|
|
374
375
|
|
|
@@ -1,16 +1,16 @@
|
|
|
1
|
-
dsipts/__init__.py,sha256=
|
|
1
|
+
dsipts/__init__.py,sha256=UWmrBJ2LLoRCKLOyTBSJAw9n31o8ZwNjLoRAax5Wll8,1694
|
|
2
2
|
dsipts/data_management/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
dsipts/data_management/monash.py,sha256=aZxq9FbIH6IsU8Lwou1hAokXjgOAK-wdl2VAeFg2k4M,13075
|
|
4
4
|
dsipts/data_management/public_datasets.py,sha256=yXFzOZZ-X0ZG1DoqVU-zFmEGVMc2033YDQhRgYxY8ws,6793
|
|
5
5
|
dsipts/data_structure/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
|
-
dsipts/data_structure/data_structure.py,sha256=
|
|
6
|
+
dsipts/data_structure/data_structure.py,sha256=87VtKelx2EPoddrVYcja9dO5rQqaS83vZlQB_NY54PI,58994
|
|
7
7
|
dsipts/data_structure/modifiers.py,sha256=qlry9dfw8pEE0GrvgwROZJkJ6oPpUnjEHPIG5qIetss,7948
|
|
8
8
|
dsipts/data_structure/utils.py,sha256=QwfKPZgSy6DIw5n6ztOdPJIAnzo4EnlMTgRbpiWnyko,6593
|
|
9
9
|
dsipts/models/Autoformer.py,sha256=ddGT3L9T4gAXNJHx1TsuYZy7j63Anyr0rkqqXaOoSu4,8447
|
|
10
10
|
dsipts/models/CrossFormer.py,sha256=iO64L3S01jxuWA9dmm8FsK1WRvBIXbZ0PQ2tZlEQg4w,6481
|
|
11
11
|
dsipts/models/D3VAE.py,sha256=NstHIniNteBRrkfL7SJ3-bJEl3l3IIxoSxavRV3j16U,6857
|
|
12
12
|
dsipts/models/Diffusion.py,sha256=pUujnrdeSSkj4jC1RORbcptt03KpuCsGVwg414o4LPg,40733
|
|
13
|
-
dsipts/models/DilatedConv.py,sha256=
|
|
13
|
+
dsipts/models/DilatedConv.py,sha256=_c0NvFuT3vbYmo9A8cQchGo1XVb0qOpzBprNEkkAgiE,14292
|
|
14
14
|
dsipts/models/DilatedConvED.py,sha256=fXk1-EWiRC5J_VIepTjYKya_D02SlEAkyiJcCjhW_XU,14004
|
|
15
15
|
dsipts/models/Duet.py,sha256=EharWHT_r7tEYIk7BkozVLPZ0xptE5mmQmeFGm3uBsA,7628
|
|
16
16
|
dsipts/models/ITransformer.py,sha256=jO8wxLaC06Wgu4GncrFFTISv3pVyfFLLhQvbEOYsz6Y,7368
|
|
@@ -18,8 +18,9 @@ dsipts/models/Informer.py,sha256=ByJ00qGk12ONFF7NZWAACzxxRb5UXcu5wpkGMYX9Cq4,692
|
|
|
18
18
|
dsipts/models/LinearTS.py,sha256=B0-Sz4POwUyl-PN2ssSx8L-ZHgwrQQPcMmreyvSS47U,9104
|
|
19
19
|
dsipts/models/PatchTST.py,sha256=Z7DM1Kw5Ym8Hh9ywj0j9RuFtKaz_yVZmKFIYafjceM8,9061
|
|
20
20
|
dsipts/models/Persistent.py,sha256=URwyaBb0M7zbPXSGMImtHlwC9XCy-OquFCwfWvn3P70,1249
|
|
21
|
-
dsipts/models/RNN.py,sha256=
|
|
21
|
+
dsipts/models/RNN.py,sha256=GbH6QyrGhvQg-Hnt_0l3YSnhNHE0Hl0AWsZpdQUAzug,9633
|
|
22
22
|
dsipts/models/Samformer.py,sha256=s61Hi1o9iuw-KgSBPfiE80oJcK1j2fUA6N9f5BJgKJc,5551
|
|
23
|
+
dsipts/models/Simple.py,sha256=K82E88A62NhV_7U9Euu2cn3Q8P287HDR7eIy7VqgwbM,3909
|
|
23
24
|
dsipts/models/TFT.py,sha256=JO2-AKIUag7bfm9Oeo4KmGfdYZJbzQBHPDqGVg0WUZI,13830
|
|
24
25
|
dsipts/models/TIDE.py,sha256=i8qXac2gImEVgE2X6cNxqW5kuQP3rzWMlQNdgJbNmKM,13033
|
|
25
26
|
dsipts/models/TTM.py,sha256=WpCiTN0qX3JFO6xgPLedoqMKXUC2pQpNAe9ee-Rw89Q,10602
|
|
@@ -27,8 +28,8 @@ dsipts/models/TimeXER.py,sha256=aCg0003LxYZzqZWyWugpbW_iOybcdHN4OH6_v77qp4o,7056
|
|
|
27
28
|
dsipts/models/VQVAEA.py,sha256=sNJi8UZh-10qEIKcZK3SzhlOFUUjvqjoglzeZBFaeZM,13789
|
|
28
29
|
dsipts/models/VVA.py,sha256=BnPkJ0Nzue0oShSHZVRNlf5RvT0Iwtf9bx19vLB9Nn0,11939
|
|
29
30
|
dsipts/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
|
-
dsipts/models/base.py,sha256=
|
|
31
|
-
dsipts/models/base_v2.py,sha256=
|
|
31
|
+
dsipts/models/base.py,sha256=mIsEUkuyj_2MlYEvH97PPD790DrS0PQw4UCiWN8uqKI,18159
|
|
32
|
+
dsipts/models/base_v2.py,sha256=jjlX5fIw2stCx5J3i3xFTgzYmCX-n8Lf4-4cLoq-diQ,18426
|
|
32
33
|
dsipts/models/utils.py,sha256=H1lr1lukDk7FNyXXTJh217tyTBsBW8hVDQ6jL9oev7I,21765
|
|
33
34
|
dsipts/models/autoformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
35
|
dsipts/models/autoformer/layers.py,sha256=xHt8V1lKdD1cIvgxXdDbI_EqOz4zgOQ6LP8l7M1pAxM,13276
|
|
@@ -75,7 +76,7 @@ dsipts/models/vva/minigpt.py,sha256=bg0JddqSD322uxSGexen3nPXL_hGTsk3vNLR62d7-w8,
|
|
|
75
76
|
dsipts/models/vva/vqvae.py,sha256=RzCQ_M9xBprp7_x20dSV3EQqlO0FjPUGWV-qdyKrQsM,19680
|
|
76
77
|
dsipts/models/xlstm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
77
78
|
dsipts/models/xlstm/xLSTM.py,sha256=ZKZZmffmIq1Vb71CR4GSyM8viqVx-u0FChxhcNgHub8,10081
|
|
78
|
-
dsipts-1.1.
|
|
79
|
-
dsipts-1.1.
|
|
80
|
-
dsipts-1.1.
|
|
81
|
-
dsipts-1.1.
|
|
79
|
+
dsipts-1.1.8.dist-info/METADATA,sha256=fObwUSnqEBaCA_sDxvmOnfKsmb-Mu9gOrITzl3Tp4qQ,24794
|
|
80
|
+
dsipts-1.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
81
|
+
dsipts-1.1.8.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
|
|
82
|
+
dsipts-1.1.8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|