dsipts 1.1.7__py3-none-any.whl → 1.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic. Click here for more details.
- dsipts/__init__.py +2 -2
- dsipts/data_structure/data_structure.py +1 -1
- dsipts/models/RNN.py +1 -1
- dsipts/models/Simple.py +113 -0
- dsipts/models/base.py +3 -8
- dsipts/models/base_v2.py +3 -11
- {dsipts-1.1.7.dist-info → dsipts-1.1.8.dist-info}/METADATA +5 -8
- {dsipts-1.1.7.dist-info → dsipts-1.1.8.dist-info}/RECORD +10 -9
- {dsipts-1.1.7.dist-info → dsipts-1.1.8.dist-info}/WHEEL +0 -0
- {dsipts-1.1.7.dist-info → dsipts-1.1.8.dist-info}/top_level.txt +0 -0
dsipts/__init__.py
CHANGED
|
@@ -25,7 +25,7 @@ from .models.TimeXER import TimeXER
|
|
|
25
25
|
from .models.TTM import TTM
|
|
26
26
|
from .models.Samformer import Samformer
|
|
27
27
|
from .models.Duet import Duet
|
|
28
|
-
|
|
28
|
+
from .models.Simple import Simple
|
|
29
29
|
try:
|
|
30
30
|
import lightning.pytorch as pl
|
|
31
31
|
from .models.base_v2 import Base
|
|
@@ -44,5 +44,5 @@ __all__ = [
|
|
|
44
44
|
"RNN", "LinearTS", "Persistent", "D3VAE", "DilatedConv", "TFT",
|
|
45
45
|
"Informer", "VVA", "VQVAEA", "CrossFormer", "Autoformer", "PatchTST",
|
|
46
46
|
"Diffusion", "DilatedConvED", "TIDE", "ITransformer", "TimeXER",
|
|
47
|
-
"TTM", "Samformer", "Duet", "Base"
|
|
47
|
+
"TTM", "Samformer", "Duet", "Base", "Simple"
|
|
48
48
|
]
|
dsipts/models/RNN.py
CHANGED
|
@@ -16,7 +16,7 @@ from ..data_structure.utils import beauty_string
|
|
|
16
16
|
from .utils import get_scope
|
|
17
17
|
from .xlstm.xLSTM import xLSTM
|
|
18
18
|
from .utils import Embedding_cat_variables
|
|
19
|
-
|
|
19
|
+
torch.autograd.set_detect_anomaly(True)
|
|
20
20
|
|
|
21
21
|
class MyBN(nn.Module):
|
|
22
22
|
def __init__(self,channels):
|
dsipts/models/Simple.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
|
|
2
|
+
## Copyright 2022 DLinear Authors (https://github.com/cure-lab/LTSF-Linear/tree/main?tab=Apache-2.0-1-ov-file#readme)
|
|
3
|
+
## Code modified to align the notation and the batch generation
|
|
4
|
+
## extended to all present in informer, autoformer folder
|
|
5
|
+
|
|
6
|
+
from torch import nn
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import lightning.pytorch as pl
|
|
11
|
+
from .base_v2 import Base
|
|
12
|
+
OLD_PL = False
|
|
13
|
+
except:
|
|
14
|
+
import pytorch_lightning as pl
|
|
15
|
+
OLD_PL = True
|
|
16
|
+
from .base import Base
|
|
17
|
+
from .utils import QuantileLossMO, get_activation
|
|
18
|
+
from typing import List, Union
|
|
19
|
+
from ..data_structure.utils import beauty_string
|
|
20
|
+
from .utils import get_scope
|
|
21
|
+
from .utils import Embedding_cat_variables
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Simple(Base):
    """Single-hidden-layer MLP forecaster over flattened past and future inputs.

    The past numerical inputs (``batch['x_num_past']``), the embedded past and
    future categorical inputs and, when present, the future numerical inputs
    (``batch['x_num_future']``) are each flattened and concatenated into one
    vector per sample. That vector is mapped by
    ``Linear -> activation -> Dropout -> Linear`` to
    ``out_channels * future_steps * mul`` values.
    """

    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate, handle_future_covariates,
                            handle_categorical_variables, handle_quantile_loss)
    # NOTE(review): the sentence below contradicts forward(), which does consume
    # 'x_cat_past', 'x_cat_future' and 'x_num_future' when present in the batch.
    # Confirm the intended wording before changing this user-facing text.
    description += '\n THE SIMPLE IMPLEMENTATION DOES NOT USE CATEGORICAL NOR FUTURE VARIABLES'

    def __init__(self,
                 hidden_size: int,
                 dropout_rate: float = 0.1,
                 activation: str = 'torch.nn.ReLU',
                 **kwargs) -> None:
        """Build the categorical embeddings and the MLP head.

        Args:
            hidden_size: width of the single hidden layer.
            dropout_rate: dropout probability applied after the activation.
            activation: dotted name of the activation class, resolved with
                ``get_activation``. A non-string value is used as-is (this
                happens when the constructor is re-invoked with already
                resolved hyperparameters).
            **kwargs: forwarded to ``Base.__init__``. Presumably this supplies
                ``past_steps``, ``future_steps``, the channel counts and the
                embedding settings read below (TODO: confirm against Base).
        """
        super().__init__(**kwargs)

        if activation == 'torch.nn.SELU':
            # Informational only: this model has no batch-norm layer to disable.
            # The previous version also set an unused `use_bn` local here.
            beauty_string('SELU do not require BN', 'info', self.verbose)

        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in pytorch lightening, the constructior is called twice', 'info', self.verbose)

        # NOTE(review): Lightning's save_hyperparameters reads the constructor's
        # local argument values at call time. `activation` has already been
        # replaced by the resolved class here, which is likely why the else
        # branch above fires on reload. Confirm before reordering.
        self.save_hyperparameters(logger=False)

        self.emb_past = Embedding_cat_variables(
            self.past_steps, self.emb_dim, self.embs_past,
            reduction_mode=self.reduction_mode,
            use_classical_positional_encoder=self.use_classical_positional_encoder,
            device=self.device,
        )
        self.emb_fut = Embedding_cat_variables(
            self.future_steps, self.emb_dim, self.embs_fut,
            reduction_mode=self.reduction_mode,
            use_classical_positional_encoder=self.use_classical_positional_encoder,
            device=self.device,
        )
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        # Width of the concatenated, flattened input built in forward():
        # embedded past + embedded future + past numerical + future numerical.
        in_features = (emb_past_out_channel * self.past_steps
                       + emb_fut_out_channel * self.future_steps
                       + self.past_steps * self.past_channels
                       + self.future_channels * self.future_steps)

        self.linear = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            activation(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, self.out_channels * self.future_steps * self.mul),
        )

    def forward(self, batch):
        """Predict the future horizon for one batch.

        Args:
            batch: mapping of tensors. ``'x_num_past'`` is required;
                ``'x_cat_past'``, ``'x_cat_future'`` and ``'x_num_future'``
                are optional.

        Returns:
            The MLP output reshaped to ``(BS, future_steps, -1, mul)``.
        """
        x = batch['x_num_past'].to(self.device)
        BS = x.shape[0]

        # When a categorical input is absent, the embedding module is called with None.
        if 'x_cat_future' in batch:
            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS, None)
        if 'x_cat_past' in batch:
            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS, None)

        if 'x_num_future' in batch:
            x_future = batch['x_num_future'].to(self.device)
        else:
            x_future = None

        # Join along dim 2 (assumed to be the channel axis: (BS, steps, channels);
        # TODO confirm), then flatten everything except the batch dimension.
        tot_past = torch.cat([x, emb_past], 2).flatten(1)

        future_parts = [emb_fut]
        if x_future is not None:
            future_parts.append(x_future)
        tot_future = torch.cat(future_parts, 2).flatten(1)

        tot = torch.cat([tot_past, tot_future], 1)

        res = self.linear(tot)
        res = res.reshape(BS, self.future_steps, -1, self.mul)
        return res
|
|
113
|
+
|
dsipts/models/base.py
CHANGED
|
@@ -428,12 +428,7 @@ class Base(pl.LightningModule):
|
|
|
428
428
|
|
|
429
429
|
elif self.loss_type=='dilated':
|
|
430
430
|
#BxLxCxMUL
|
|
431
|
-
|
|
432
|
-
alpha = 0.25
|
|
433
|
-
if self.persistence_weight==1:
|
|
434
|
-
alpha = 0.5
|
|
435
|
-
else:
|
|
436
|
-
alpha =0.75
|
|
431
|
+
|
|
437
432
|
alpha = self.persistence_weight
|
|
438
433
|
gamma = 0.01
|
|
439
434
|
loss = 0
|
|
@@ -444,8 +439,8 @@ class Base(pl.LightningModule):
|
|
|
444
439
|
loss+= dilate_loss( batch['y'][:,:,i:i+1],x[:,:,i:i+1], alpha, gamma, y_hat.device)
|
|
445
440
|
|
|
446
441
|
elif self.loss_type=='huber':
|
|
447
|
-
loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight
|
|
448
|
-
|
|
442
|
+
loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
|
|
443
|
+
|
|
449
444
|
if self.use_quantiles is False:
|
|
450
445
|
x = y_hat[:,:,:,0]
|
|
451
446
|
else:
|
dsipts/models/base_v2.py
CHANGED
|
@@ -437,24 +437,16 @@ class Base(pl.LightningModule):
|
|
|
437
437
|
|
|
438
438
|
elif self.loss_type=='dilated':
|
|
439
439
|
#BxLxCxMUL
|
|
440
|
-
|
|
441
|
-
alpha = 0.25
|
|
442
|
-
if self.persistence_weight==1:
|
|
443
|
-
alpha = 0.5
|
|
444
|
-
else:
|
|
445
|
-
alpha =0.75
|
|
440
|
+
|
|
446
441
|
alpha = self.persistence_weight
|
|
447
442
|
gamma = 0.01
|
|
448
443
|
loss = 0
|
|
449
444
|
##no multichannel here
|
|
450
|
-
for i in range(y_hat.shape[2]):
|
|
451
|
-
##error here
|
|
452
|
-
|
|
445
|
+
for i in range(y_hat.shape[2]):
|
|
453
446
|
loss+= dilate_loss( batch['y'][:,:,i:i+1],x[:,:,i:i+1], alpha, gamma, y_hat.device)
|
|
454
447
|
|
|
455
448
|
elif self.loss_type=='huber':
|
|
456
|
-
loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight
|
|
457
|
-
#loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
|
|
449
|
+
loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
|
|
458
450
|
if self.use_quantiles is False:
|
|
459
451
|
x = y_hat[:,:,:,0]
|
|
460
452
|
else:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dsipts
|
|
3
|
-
Version: 1.1.
|
|
3
|
+
Version: 1.1.8
|
|
4
4
|
Summary: Unified library for timeseries modelling
|
|
5
5
|
Author-email: Andrea Gobbi <agobbi@fbk.eu>
|
|
6
6
|
Project-URL: Homepage, https://github.com/DSIP-FBK/DSIPTS
|
|
@@ -356,6 +356,9 @@ loss, quantile loss, MDA and a couple of experimental losses for minimizing the
|
|
|
356
356
|
# Bash experiment
|
|
357
357
|
Most of the time you want to train the models on a cluster with a GPU, and a command-line training procedure can help speed up the process. DSIPTS leverages OmegaConf-Hydra to do this, and in the folder `bash_examples` you can find some examples. Please read the documentation [here](/bash_examples/README.md)
|
|
358
358
|
|
|
359
|
+
## Losses
|
|
360
|
+
|
|
361
|
+
- `dilated`: `persistence_weight` between 0 and 1
|
|
359
362
|
|
|
360
363
|
|
|
361
364
|
# Modifiers
|
|
@@ -366,13 +369,7 @@ The VVA model is composed by two steps: the first is a clusterting procedure tha
|
|
|
366
369
|
- **inverse_transform**: the output of the model are reverted to the original shape. In the VVA model the centroids are used for reconstruct the predicted timeseries.
|
|
367
370
|
|
|
368
371
|
|
|
369
|
-
|
|
370
|
-
You can find the documentation [here](https://dsip.pages.fbk.eu/dsip_dlresearch/timeseries/):
|
|
371
|
-
or in the folder `docs/_build/html/index.html`
|
|
372
|
-
If yon need to generate the documentation after some modification just run:
|
|
373
|
-
```
|
|
374
|
-
./make_doc.sh
|
|
375
|
-
```
|
|
372
|
+
|
|
376
373
|
|
|
377
374
|
For users only: make sure that the CI file has pages enabled; see [public pages](https://roneo.org/en/gitlab-public-pages-private-repo/)
|
|
378
375
|
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
dsipts/__init__.py,sha256=
|
|
1
|
+
dsipts/__init__.py,sha256=UWmrBJ2LLoRCKLOyTBSJAw9n31o8ZwNjLoRAax5Wll8,1694
|
|
2
2
|
dsipts/data_management/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
dsipts/data_management/monash.py,sha256=aZxq9FbIH6IsU8Lwou1hAokXjgOAK-wdl2VAeFg2k4M,13075
|
|
4
4
|
dsipts/data_management/public_datasets.py,sha256=yXFzOZZ-X0ZG1DoqVU-zFmEGVMc2033YDQhRgYxY8ws,6793
|
|
5
5
|
dsipts/data_structure/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
|
-
dsipts/data_structure/data_structure.py,sha256=
|
|
6
|
+
dsipts/data_structure/data_structure.py,sha256=87VtKelx2EPoddrVYcja9dO5rQqaS83vZlQB_NY54PI,58994
|
|
7
7
|
dsipts/data_structure/modifiers.py,sha256=qlry9dfw8pEE0GrvgwROZJkJ6oPpUnjEHPIG5qIetss,7948
|
|
8
8
|
dsipts/data_structure/utils.py,sha256=QwfKPZgSy6DIw5n6ztOdPJIAnzo4EnlMTgRbpiWnyko,6593
|
|
9
9
|
dsipts/models/Autoformer.py,sha256=ddGT3L9T4gAXNJHx1TsuYZy7j63Anyr0rkqqXaOoSu4,8447
|
|
@@ -18,8 +18,9 @@ dsipts/models/Informer.py,sha256=ByJ00qGk12ONFF7NZWAACzxxRb5UXcu5wpkGMYX9Cq4,692
|
|
|
18
18
|
dsipts/models/LinearTS.py,sha256=B0-Sz4POwUyl-PN2ssSx8L-ZHgwrQQPcMmreyvSS47U,9104
|
|
19
19
|
dsipts/models/PatchTST.py,sha256=Z7DM1Kw5Ym8Hh9ywj0j9RuFtKaz_yVZmKFIYafjceM8,9061
|
|
20
20
|
dsipts/models/Persistent.py,sha256=URwyaBb0M7zbPXSGMImtHlwC9XCy-OquFCwfWvn3P70,1249
|
|
21
|
-
dsipts/models/RNN.py,sha256=
|
|
21
|
+
dsipts/models/RNN.py,sha256=GbH6QyrGhvQg-Hnt_0l3YSnhNHE0Hl0AWsZpdQUAzug,9633
|
|
22
22
|
dsipts/models/Samformer.py,sha256=s61Hi1o9iuw-KgSBPfiE80oJcK1j2fUA6N9f5BJgKJc,5551
|
|
23
|
+
dsipts/models/Simple.py,sha256=K82E88A62NhV_7U9Euu2cn3Q8P287HDR7eIy7VqgwbM,3909
|
|
23
24
|
dsipts/models/TFT.py,sha256=JO2-AKIUag7bfm9Oeo4KmGfdYZJbzQBHPDqGVg0WUZI,13830
|
|
24
25
|
dsipts/models/TIDE.py,sha256=i8qXac2gImEVgE2X6cNxqW5kuQP3rzWMlQNdgJbNmKM,13033
|
|
25
26
|
dsipts/models/TTM.py,sha256=WpCiTN0qX3JFO6xgPLedoqMKXUC2pQpNAe9ee-Rw89Q,10602
|
|
@@ -27,8 +28,8 @@ dsipts/models/TimeXER.py,sha256=aCg0003LxYZzqZWyWugpbW_iOybcdHN4OH6_v77qp4o,7056
|
|
|
27
28
|
dsipts/models/VQVAEA.py,sha256=sNJi8UZh-10qEIKcZK3SzhlOFUUjvqjoglzeZBFaeZM,13789
|
|
28
29
|
dsipts/models/VVA.py,sha256=BnPkJ0Nzue0oShSHZVRNlf5RvT0Iwtf9bx19vLB9Nn0,11939
|
|
29
30
|
dsipts/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
|
-
dsipts/models/base.py,sha256=
|
|
31
|
-
dsipts/models/base_v2.py,sha256=
|
|
31
|
+
dsipts/models/base.py,sha256=mIsEUkuyj_2MlYEvH97PPD790DrS0PQw4UCiWN8uqKI,18159
|
|
32
|
+
dsipts/models/base_v2.py,sha256=jjlX5fIw2stCx5J3i3xFTgzYmCX-n8Lf4-4cLoq-diQ,18426
|
|
32
33
|
dsipts/models/utils.py,sha256=H1lr1lukDk7FNyXXTJh217tyTBsBW8hVDQ6jL9oev7I,21765
|
|
33
34
|
dsipts/models/autoformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
35
|
dsipts/models/autoformer/layers.py,sha256=xHt8V1lKdD1cIvgxXdDbI_EqOz4zgOQ6LP8l7M1pAxM,13276
|
|
@@ -75,7 +76,7 @@ dsipts/models/vva/minigpt.py,sha256=bg0JddqSD322uxSGexen3nPXL_hGTsk3vNLR62d7-w8,
|
|
|
75
76
|
dsipts/models/vva/vqvae.py,sha256=RzCQ_M9xBprp7_x20dSV3EQqlO0FjPUGWV-qdyKrQsM,19680
|
|
76
77
|
dsipts/models/xlstm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
77
78
|
dsipts/models/xlstm/xLSTM.py,sha256=ZKZZmffmIq1Vb71CR4GSyM8viqVx-u0FChxhcNgHub8,10081
|
|
78
|
-
dsipts-1.1.
|
|
79
|
-
dsipts-1.1.
|
|
80
|
-
dsipts-1.1.
|
|
81
|
-
dsipts-1.1.
|
|
79
|
+
dsipts-1.1.8.dist-info/METADATA,sha256=fObwUSnqEBaCA_sDxvmOnfKsmb-Mu9gOrITzl3Tp4qQ,24794
|
|
80
|
+
dsipts-1.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
81
|
+
dsipts-1.1.8.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
|
|
82
|
+
dsipts-1.1.8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|