dsipts 1.1.7.tar.gz → 1.1.9.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dsipts might be problematic.
- {dsipts-1.1.7 → dsipts-1.1.9}/PKG-INFO +5 -8
- {dsipts-1.1.7 → dsipts-1.1.9}/README.md +4 -7
- {dsipts-1.1.7 → dsipts-1.1.9}/pyproject.toml +1 -1
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/__init__.py +2 -2
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/data_structure/data_structure.py +1 -1
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/RNN.py +1 -1
- dsipts-1.1.9/src/dsipts/models/Simple.py +113 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/base.py +25 -13
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/base_v2.py +28 -16
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/utils.py +73 -1
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts.egg-info/PKG-INFO +5 -8
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts.egg-info/SOURCES.txt +1 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/setup.cfg +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/data_management/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/data_management/monash.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/data_management/public_datasets.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/data_structure/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/data_structure/modifiers.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/data_structure/utils.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/Autoformer.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/CrossFormer.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/D3VAE.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/Diffusion.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/DilatedConv.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/DilatedConvED.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/Duet.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/ITransformer.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/Informer.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/LinearTS.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/PatchTST.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/Persistent.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/Samformer.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/TFT.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/TIDE.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/TTM.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/TimeXER.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/VQVAEA.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/VVA.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/autoformer/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/autoformer/layers.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/crossformer/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/crossformer/attn.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/crossformer/cross_decoder.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/crossformer/cross_embed.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/crossformer/cross_encoder.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/d3vae/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/d3vae/diffusion_process.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/d3vae/embedding.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/d3vae/encoder.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/d3vae/model.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/d3vae/neural_operations.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/d3vae/resnet.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/d3vae/utils.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/duet/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/duet/layers.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/duet/masked.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/informer/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/informer/attn.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/informer/decoder.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/informer/embed.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/informer/encoder.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/itransformer/Embed.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/itransformer/SelfAttention_Family.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/itransformer/Transformer_EncDec.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/itransformer/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/patchtst/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/patchtst/layers.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/samformer/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/samformer/utils.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/tft/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/tft/sub_nn.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/timexer/Layers.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/timexer/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/ttm/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/ttm/configuration_tinytimemixer.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/ttm/consts.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/ttm/modeling_tinytimemixer.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/ttm/utils.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/vva/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/vva/minigpt.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/vva/vqvae.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/xlstm/__init__.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/xlstm/xLSTM.py +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts.egg-info/dependency_links.txt +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts.egg-info/requires.txt +0 -0
- {dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts.egg-info/top_level.txt +0 -0
{dsipts-1.1.7 → dsipts-1.1.9}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dsipts
-Version: 1.1.7
+Version: 1.1.9
 Summary: Unified library for timeseries modelling
 Author-email: Andrea Gobbi <agobbi@fbk.eu>
 Project-URL: Homepage, https://github.com/DSIP-FBK/DSIPTS
@@ -356,6 +356,9 @@ loss, quantile loss, MDA and a couple of experimental losses for minimizing the
 # Bash experiment
 Most of the time you want to train the models in a cluster with a GPU and command line training procedure can help speedup the process. DSIPTS leverages on OmegaConf-Hydra to to this and in the folder `bash_examples` you can find an examples. Please read the documentation [here](/bash_examples/README.md)
 
+## Losses
+
+- `dilated`: `persistence_weight` between 0 and 1
 
 
 # Modifiers
@@ -366,13 +369,7 @@ The VVA model is composed by two steps: the first is a clusterting procedure tha
 - **inverse_transform**: the output of the model are reverted to the original shape. In the VVA model the centroids are used for reconstruct the predicted timeseries.
 
 
-
-You can find the documentation [here](https://dsip.pages.fbk.eu/dsip_dlresearch/timeseries/):
-or in the folder `docs/_build/html/index.html`
-If yon need to generate the documentation after some modification just run:
-```
-./make_doc.sh
-```
+
 
 For user only: be sure that the the CI file has pages enabled, see [public pages](https://roneo.org/en/gitlab-public-pages-private-repo/)
 
```
{dsipts-1.1.7 → dsipts-1.1.9}/README.md

```diff
@@ -324,6 +324,9 @@ loss, quantile loss, MDA and a couple of experimental losses for minimizing the
 # Bash experiment
 Most of the time you want to train the models in a cluster with a GPU and command line training procedure can help speedup the process. DSIPTS leverages on OmegaConf-Hydra to to this and in the folder `bash_examples` you can find an examples. Please read the documentation [here](/bash_examples/README.md)
 
+## Losses
+
+- `dilated`: `persistence_weight` between 0 and 1
 
 
 # Modifiers
@@ -334,13 +337,7 @@ The VVA model is composed by two steps: the first is a clusterting procedure tha
 - **inverse_transform**: the output of the model are reverted to the original shape. In the VVA model the centroids are used for reconstruct the predicted timeseries.
 
 
-
-You can find the documentation [here](https://dsip.pages.fbk.eu/dsip_dlresearch/timeseries/):
-or in the folder `docs/_build/html/index.html`
-If yon need to generate the documentation after some modification just run:
-```
-./make_doc.sh
-```
+
 
 For user only: be sure that the the CI file has pages enabled, see [public pages](https://roneo.org/en/gitlab-public-pages-private-repo/)
 
```
{dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/__init__.py

```diff
@@ -25,7 +25,7 @@ from .models.TimeXER import TimeXER
 from .models.TTM import TTM
 from .models.Samformer import Samformer
 from .models.Duet import Duet
-
+from .models.Simple import Simple
 try:
     import lightning.pytorch as pl
     from .models.base_v2 import Base
@@ -44,5 +44,5 @@ __all__ = [
     "RNN", "LinearTS", "Persistent", "D3VAE", "DilatedConv", "TFT",
     "Informer", "VVA", "VQVAEA", "CrossFormer", "Autoformer", "PatchTST",
     "Diffusion", "DilatedConvED", "TIDE", "ITransformer", "TimeXER",
-    "TTM", "Samformer", "Duet", "Base"
+    "TTM", "Samformer", "Duet", "Base", "Simple"
 ]
```
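With the export added above, the new model becomes available at package level; a minimal usage sketch (assuming dsipts 1.1.9 is installed):

```python
# Hedged sketch: Simple is newly re-exported from the package root in 1.1.9.
from dsipts import Simple

# description is a class attribute assembled via get_scope(...) in Simple.py.
print(Simple.description)
```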
{dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/RNN.py

```diff
@@ -16,7 +16,7 @@ from ..data_structure.utils import beauty_string
 from .utils import get_scope
 from .xlstm.xLSTM import xLSTM
 from .utils import Embedding_cat_variables
-
+torch.autograd.set_detect_anomaly(True)
 
 class MyBN(nn.Module):
     def __init__(self,channels):
```
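Note that `torch.autograd.set_detect_anomaly(True)` at module import time enables anomaly detection globally, which adds overhead to every backward pass. For reference, PyTorch also offers a scoped form; a small standalone sketch:

```python
import torch

# Scoped alternative to the global set_detect_anomaly(True) call:
# anomaly detection (with its runtime cost) applies only inside this block.
with torch.autograd.detect_anomaly():
    x = torch.randn(3, requires_grad=True)
    loss = (x ** 2).sum()
    loss.backward()
```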
dsipts-1.1.9/src/dsipts/models/Simple.py (new file)

```diff
@@ -0,0 +1,113 @@
+
+## Copyright 2022 DLinear Authors (https://github.com/cure-lab/LTSF-Linear/tree/main?tab=Apache-2.0-1-ov-file#readme)
+## Code modified for align the notation and the batch generation
+## extended to all present in informer, autoformer folder
+
+from torch import nn
+import torch
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+from .utils import QuantileLossMO, get_activation
+from typing import List, Union
+from ..data_structure.utils import beauty_string
+from .utils import get_scope
+from .utils import Embedding_cat_variables
+
+
+
+
+class Simple(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+    description+='\n THE SIMPLE IMPLEMENTATION DOES NOT USE CATEGORICAL NOR FUTURE VARIABLES'
+
+    def __init__(self,
+
+                 hidden_size:int,
+                 dropout_rate:float=0.1,
+                 activation:str='torch.nn.ReLU',
+
+
+                 **kwargs)->None:
+
+        super().__init__(**kwargs)
+
+        if activation == 'torch.nn.SELU':
+            beauty_string('SELU do not require BN','info',self.verbose)
+            use_bn = False
+
+        if isinstance(activation, str):
+            activation = get_activation(activation)
+        else:
+            beauty_string('There is a bug in pytorch lightening, the constructior is called twice','info',self.verbose)
+
+        self.save_hyperparameters(logger=False)
+
+
+
+        self.emb_past = Embedding_cat_variables(self.past_steps,self.emb_dim,self.embs_past, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
+        self.emb_fut = Embedding_cat_variables(self.future_steps,self.emb_dim,self.embs_fut, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
+        emb_past_out_channel = self.emb_past.output_channels
+        emb_fut_out_channel = self.emb_fut.output_channels
+
+
+
+
+
+        self.linear = (nn.Sequential(nn.Linear(emb_past_out_channel*self.past_steps+emb_fut_out_channel*self.future_steps+self.past_steps*self.past_channels+self.future_channels*self.future_steps,hidden_size),
+                                     activation(),nn.Dropout(dropout_rate),
+                                     nn.Linear(hidden_size,self.out_channels*self.future_steps*self.mul)))
+
+    def forward(self, batch):
+
+        x = batch['x_num_past'].to(self.device)
+
+        BS = x.shape[0]
+        if 'x_cat_future' in batch.keys():
+            emb_fut = self.emb_fut(BS,batch['x_cat_future'].to(self.device))
+        else:
+            emb_fut = self.emb_fut(BS,None)
+        if 'x_cat_past' in batch.keys():
+            emb_past = self.emb_past(BS,batch['x_cat_past'].to(self.device))
+        else:
+            emb_past = self.emb_past(BS,None)
+
+
+
+
+        if 'x_num_future' in batch.keys():
+            x_future = batch['x_num_future'].to(self.device)
+        else:
+            x_future = None
+
+        tmp = [x,emb_past]
+        tot_past = torch.cat(tmp,2).flatten(1)
+
+
+
+        tmp = [emb_fut]
+
+        if x_future is not None:
+            tmp.append(x_future)
+
+        tot_future = torch.cat(tmp,2).flatten(1)
+        tot = torch.cat([tot_past,tot_future],1)
+
+
+        res = self.linear(tot)
+        res = res.reshape(BS,self.future_steps,-1,self.mul)
+
+        ##
+
+        return res
+
```
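From the forward method above, `Simple` consumes the usual dsipts batch dictionary and flattens everything into one linear block. A sketch of the tensor shapes it expects; all sizes below are illustrative assumptions, the real ones come from the `Base` constructor kwargs:

```python
import torch

# Illustrative sizes only (hypothetical): in practice past_steps,
# future_steps and channel counts are fixed by the Base kwargs.
BS, past_steps, future_steps = 8, 16, 4
past_channels, future_channels = 3, 2

batch = {
    'x_num_past': torch.randn(BS, past_steps, past_channels),        # required
    'x_num_future': torch.randn(BS, future_steps, future_channels),  # optional
    # 'x_cat_past' / 'x_cat_future' (optional): integer tensors of
    # categorical codes; when missing, forward falls back to emb(BS, None).
}
# forward concatenates [x, emb_past] and [emb_fut, x_future], flattens both,
# joins them, and reshapes the linear output to (BS, future_steps, -1, mul).
```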
{dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/base.py

```diff
@@ -12,7 +12,7 @@ import numpy as np
 from aim import Image
 import matplotlib.pyplot as plt
 from typing import List, Union
-from .utils import QuantileLossMO
+from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
 
 def standardize_momentum(x,order):
@@ -135,10 +135,15 @@ class Base(pl.LightningModule):
         if n_classes==0:
             self.is_classification = False
             if len(self.quantiles)>0:
-                assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
-                self.use_quantiles = True
-                self.mul = len(self.quantiles)
-                self.loss = QuantileLossMO(quantiles)
+                if self.loss_type=='cprs':
+                    self.use_quantiles = False
+                    self.mul = len(self.quantiles)
+                    self.loss = CPRS()
+                else:
+                    assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
+                    self.use_quantiles = True
+                    self.mul = len(self.quantiles)
+                    self.loss = QuantileLossMO(quantiles)
             else:
                 self.use_quantiles = False
                 self.mul = 1
@@ -186,6 +191,10 @@ class Base(pl.LightningModule):
         Returns:
             torch.tensor: result
         """
+        if self.loss_type=='cprs':
+            tmp = self(batch)
+            return tmp.mean(axis=-1)
+
         return self(batch)
 
     def configure_optimizers(self):
@@ -357,6 +366,14 @@ class Base(pl.LightningModule):
         :meta private:
         """
 
+        if self.loss_type=='cprs':
+            return self.loss(y_hat,batch['y'])
+
+        if self.loss_type=='long_lag':
+            batch_size,width,n_variables = batch['y'].shape
+            tmp = torch.abs(y_hat[:,:,:,0]-batch['y'])*torch.linspace(1,self.persistence_weight,width).view(1,width,1).repeat(batch_size,1,n_variables)
+            return tmp.mean()
+
         if self.use_quantiles is False:
             initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
         else:
@@ -428,12 +445,7 @@ class Base(pl.LightningModule):
 
         elif self.loss_type=='dilated':
             #BxLxCxMUL
-
-            alpha = 0.25
-            if self.persistence_weight==1:
-                alpha = 0.5
-            else:
-                alpha =0.75
+
             alpha = self.persistence_weight
             gamma = 0.01
             loss = 0
@@ -444,8 +456,8 @@ class Base(pl.LightningModule):
             loss+= dilate_loss( batch['y'][:,:,i:i+1],x[:,:,i:i+1], alpha, gamma, y_hat.device)
 
         elif self.loss_type=='huber':
-            loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight
-
+            loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
+
             if self.use_quantiles is False:
                 x = y_hat[:,:,:,0]
             else:
```
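The new `long_lag` branch scales absolute errors with a linear ramp over the forecast horizon, from 1 at the first step to `persistence_weight` at the last, so late-horizon errors weigh more when `persistence_weight > 1`. A standalone sketch with hypothetical tensors (broadcasting replaces the explicit `.repeat` used above, with the same result):

```python
import torch

batch_size, width, n_variables = 4, 10, 2   # hypothetical sizes
persistence_weight = 3.0                    # ramp end value

y = torch.randn(batch_size, width, n_variables)          # batch['y']
y_hat = torch.randn(batch_size, width, n_variables, 1)   # mul = 1 output

# Weight step t along the horizon by linspace(1, persistence_weight)[t],
# broadcast over batch and variables.
ramp = torch.linspace(1, persistence_weight, width).view(1, width, 1)
loss = (torch.abs(y_hat[:, :, :, 0] - y) * ramp).mean()
```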
{dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/base_v2.py

```diff
@@ -12,7 +12,7 @@ import numpy as np
 from aim import Image
 import matplotlib.pyplot as plt
 from typing import List, Union
-from .utils import QuantileLossMO
+from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
 
 
@@ -137,10 +137,15 @@ class Base(pl.LightningModule):
         if n_classes==0:
             self.is_classification = False
             if len(self.quantiles)>0:
-                assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
-                self.use_quantiles = True
-                self.mul = len(self.quantiles)
-                self.loss = QuantileLossMO(quantiles)
+                if self.loss_type=='cprs':
+                    self.use_quantiles = False
+                    self.mul = len(self.quantiles)
+                    self.loss = CPRS()
+                else:
+                    assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
+                    self.use_quantiles = True
+                    self.mul = len(self.quantiles)
+                    self.loss = QuantileLossMO(quantiles)
             else:
                 self.use_quantiles = False
                 self.mul = 1
@@ -189,6 +194,11 @@ class Base(pl.LightningModule):
         Returns:
             torch.tensor: result
         """
+
+        if self.loss_type=='cprs':
+            tmp = self(batch)
+            return tmp.mean(axis=-1)
+
         return self(batch)
 
     def configure_optimizers(self):
@@ -365,6 +375,16 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
+        if self.loss_type=='cprs':
+            return self.loss(y_hat,batch['y'])
+
+        if self.loss_type=='long_lag':
+
+            batch_size,width,n_variables = batch['y'].shape
+            tmp = torch.abs(y_hat[:,:,:,0]-batch['y'])*torch.linspace(1,self.persistence_weight,width).view(1,width,1).repeat(batch_size,1,n_variables)
+            return tmp.mean()
+
+
 
         if self.use_quantiles is False:
             initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
@@ -437,24 +457,16 @@ class Base(pl.LightningModule):
 
         elif self.loss_type=='dilated':
             #BxLxCxMUL
-
-            alpha = 0.25
-            if self.persistence_weight==1:
-                alpha = 0.5
-            else:
-                alpha =0.75
+
             alpha = self.persistence_weight
             gamma = 0.01
             loss = 0
             ##no multichannel here
-            for i in range(y_hat.shape[2]):
-                ##error here
-
+            for i in range(y_hat.shape[2]):
                 loss+= dilate_loss( batch['y'][:,:,i:i+1],x[:,:,i:i+1], alpha, gamma, y_hat.device)
 
         elif self.loss_type=='huber':
-            loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight
-            #loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
+            loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
             if self.use_quantiles is False:
                 x = y_hat[:,:,:,0]
             else:
```
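When `loss_type == 'cprs'`, the prediction path above collapses the member axis with a mean, so the point forecast is the ensemble average; schematically:

```python
import torch

# Hypothetical output of self(batch): (BS, width, n_vars, n_members).
tmp = torch.randn(8, 12, 2, 4)
point_forecast = tmp.mean(axis=-1)   # (BS, width, n_vars), as in predict above
```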
{dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts/models/utils.py

```diff
@@ -621,4 +621,76 @@ class Embedding_cat_variables(nn.Module):
             emb.append(layer(cat_vars[:, :, index]).unsqueeze(2))
 
         cat_n_embd = torch.cat(emb,dim=2)
-        return cat_n_embd
+        return cat_n_embd
+
+
+
+class CPRS(nn.Module):
+    """
+    Efficient vectorized implementation of Almost Fair CRPS.
+
+    This version avoids explicit loops and uses broadcasting for better performance
+    with large ensembles.
+    """
+
+    def __init__(self, alpha=0.95, reduction='mean'):
+        super().__init__()
+        self.alpha = alpha
+        self.reduction = reduction
+
+    def forward(self, y_hat, target, weights=None):
+        """
+        Compute the almost fair CRPS loss (efficient version).
+
+        Args:
+            ensemble: Tensor of shape (batch_size, n_members, ...)
+            target: Tensor of shape (batch_size, ...)
+            weights: Optional per-variable or per-location weights
+
+        Returns:
+            Loss tensor
+        """
+        ## initial shape BS,width,n_variables,n_members need to go into batch_size, n_members, width, n_variables
+        ensemble = y_hat.permute(0,3,1,2)
+
+
+        batch_size, n_members = ensemble.shape[:2]
+        epsilon = (1 - self.alpha) / n_members
+
+        # Expand target to match ensemble shape
+        target_expanded = target.unsqueeze(1).expand_as(ensemble)
+
+        # Compute first term: mean absolute error to target
+        mae_term = torch.abs(ensemble - target_expanded).mean(dim=1)
+
+        # Compute second term: pairwise differences between ensemble members
+        # Use broadcasting to compute all pairwise differences efficiently
+        ensemble_i = ensemble.unsqueeze(2)  # (batch, n_members, 1, ...)
+        ensemble_j = ensemble.unsqueeze(1)  # (batch, 1, n_members, ...)
+
+        pairwise_diffs = torch.abs(ensemble_i - ensemble_j)
+
+        # Sum over all pairs (excluding diagonal)
+        # Create mask to exclude diagonal (i=j)
+        mask = ~torch.eye(n_members, dtype=torch.bool, device=ensemble.device)
+        mask = mask.view(1, n_members, n_members, *[1]*(len(ensemble.shape)-2))
+
+        # Apply mask and compute mean
+        pairwise_term = (pairwise_diffs * mask).sum(dim=(1, 2)) / (n_members * (n_members - 1))
+
+        # Combine terms according to afCRPS formula
+        loss = mae_term - (1 - epsilon) * pairwise_term
+
+        # Apply weights if provided
+        if weights is not None:
+            loss = loss * weights
+
+        # Apply reduction
+        if self.reduction == 'none':
+            return loss
+        elif self.reduction == 'sum':
+            return loss.sum()
+        elif self.reduction == 'mean':
+            return loss.mean()
+        else:
+            raise ValueError(f"Invalid reduction: {self.reduction}")
```
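A hedged usage sketch of the new loss with random tensors: the member axis plays the role of `len(self.quantiles)` in `Base`, and the import path below assumes the module layout shown in the file list:

```python
import torch
from dsipts.models.utils import CPRS  # assumed import path

BS, width, n_vars, n_members = 8, 12, 2, 4   # hypothetical sizes
y_hat = torch.randn(BS, width, n_vars, n_members, requires_grad=True)
y = torch.randn(BS, width, n_vars)

crps = CPRS(alpha=0.95, reduction='mean')
loss = crps(y_hat, y)   # permuted internally to (BS, n_members, width, n_vars)
loss.backward()
```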
{dsipts-1.1.7 → dsipts-1.1.9}/src/dsipts.egg-info/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dsipts
-Version: 1.1.7
+Version: 1.1.9
 Summary: Unified library for timeseries modelling
 Author-email: Andrea Gobbi <agobbi@fbk.eu>
 Project-URL: Homepage, https://github.com/DSIP-FBK/DSIPTS
@@ -356,6 +356,9 @@ loss, quantile loss, MDA and a couple of experimental losses for minimizing the
 # Bash experiment
 Most of the time you want to train the models in a cluster with a GPU and command line training procedure can help speedup the process. DSIPTS leverages on OmegaConf-Hydra to to this and in the folder `bash_examples` you can find an examples. Please read the documentation [here](/bash_examples/README.md)
 
+## Losses
+
+- `dilated`: `persistence_weight` between 0 and 1
 
 
 # Modifiers
@@ -366,13 +369,7 @@ The VVA model is composed by two steps: the first is a clusterting procedure tha
 - **inverse_transform**: the output of the model are reverted to the original shape. In the VVA model the centroids are used for reconstruct the predicted timeseries.
 
 
-
-You can find the documentation [here](https://dsip.pages.fbk.eu/dsip_dlresearch/timeseries/):
-or in the folder `docs/_build/html/index.html`
-If yon need to generate the documentation after some modification just run:
-```
-./make_doc.sh
-```
+
 
 For user only: be sure that the the CI file has pages enabled, see [public pages](https://roneo.org/en/gitlab-public-pages-private-repo/)
 
```