dsipts-1.1.10-py3-none-any.whl → dsipts-1.1.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsipts/data_structure/data_structure.py +38 -16
- dsipts/models/ITransformer.py +2 -8
- dsipts/models/TTM.py +58 -165
- dsipts/models/base.py +47 -35
- dsipts/models/base_v2.py +50 -34
- {dsipts-1.1.10.dist-info → dsipts-1.1.11.dist-info}/METADATA +1 -1
- {dsipts-1.1.10.dist-info → dsipts-1.1.11.dist-info}/RECORD +9 -9
- {dsipts-1.1.10.dist-info → dsipts-1.1.11.dist-info}/WHEEL +0 -0
- {dsipts-1.1.10.dist-info → dsipts-1.1.11.dist-info}/top_level.txt +0 -0
dsipts/data_structure/data_structure.py CHANGED

@@ -35,7 +35,18 @@ from .modifiers import *
 from aim.pytorch_lightning import AimLogger
 import time
 
-
+class DummyScaler():
+    def __init__(self):
+        pass
+    def fit(self,x):
+        pass
+    def transform(self,x):
+        return x
+    def inverse_transform(self,x):
+        return x
+    def fit_transform(self,x):
+        return x
+
 
 pd.options.mode.chained_assignment = None
 log = logging.getLogger(__name__)
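The new `DummyScaler` gives the data structure a no-op scaler with the same `fit`/`transform`/`inverse_transform`/`fit_transform` surface as the sklearn scalers, so it can be passed wherever a real scaler is expected when no normalization is wanted. A minimal sketch of that interchangeability (assuming the `DummyScaler` from the hunk above is in scope):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

x = np.array([[1.0], [2.0], [3.0]])
for scaler in (StandardScaler(), DummyScaler()):
    z = scaler.fit_transform(x)                          # DummyScaler returns x unchanged
    assert np.allclose(scaler.inverse_transform(z), x)   # both round-trip to the input
```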
@@ -210,20 +221,23 @@ class TimeSeries():
         self.future_variables = []
         self.target_variables = ['signal']
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
-
+        self.num_var = list(np.sort(self.num_var))
 
     def enrich(self,dataset,columns):
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            if columns =='hour':
+                dataset[columns] = dataset.time.dt.hour
+            elif columns=='dow':
+                dataset[columns] = dataset.time.dt.weekday
+            elif columns=='month':
+                dataset[columns] = dataset.time.dt.month
+            elif columns=='minute':
+                dataset[columns] = dataset.time.dt.minute
+            else:
+                if columns not in dataset.columns:
+                    beauty_string(f'I can not automatically enrich column {columns}. Please contact the developers or add it manually to your dataset.','section',True)
+        except:
+            beauty_string(f'I can not automatically enrich column {columns}. Probably not a temporal index.','section',True)
 
     def load_signal(self,data:pd.DataFrame,
                     enrich_cat:List[str] = [],
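The rewritten `enrich` derives calendar features (`hour`, `dow`, `month`, `minute`) from the dataset's `time` column and now fails soft, warning instead of raising, when the index is not temporal. The same derivations in plain pandas (a sketch, not the dsipts API):

```python
import pandas as pd

df = pd.DataFrame({'time': pd.date_range('2024-01-01', periods=4, freq='h')})
df['hour'] = df.time.dt.hour      # what enrich(dataset, 'hour') adds
df['dow'] = df.time.dt.weekday    # day of week, 0 = Monday
df['month'] = df.time.dt.month
df['minute'] = df.time.dt.minute
```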
@@ -300,7 +314,7 @@ class TimeSeries():
         if check_past:
             beauty_string('I will update past column adding all target columns, if you want to avoid this beahviour please use check_pass as false','info',self.verbose)
             past_variables = list(set(past_variables).union(set(target_variables)))
-
+        past_variables = list(np.sort(past_variables))
         self.cat_past_var = cat_past_var
         self.cat_fut_var = cat_fut_var
 
@@ -321,14 +335,18 @@ class TimeSeries():
                 beauty_string('Categorical {c} already present, it will be added to categorical variable but not call the enriching function','info',self.verbose)
             else:
                 self.enrich(dataset,c)
+        self.cat_past_var = list(np.sort(self.cat_past_var))
+        self.cat_fut_var = list(np.sort(self.cat_fut_var))
+
         self.cat_var = list(set(self.cat_past_var+self.cat_fut_var)) ## all categorical data
-
+        self.cat_var = list(np.sort(self.cat_var))
         self.dataset = dataset
         self.past_variables = past_variables
         self.future_variables = future_variables
         self.target_variables = target_variables
         self.out_vars = len(target_variables)
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
+        self.num_var = list(np.sort(self.num_var))
         if silly_model:
             beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
             self.future_variables+=self.target_variables
@@ -665,7 +683,11 @@ class TimeSeries():
         #self.model.apply(weight_init_zeros)
 
         self.config = config
-
+        try:
+            self.model = torch.compile(self.model)
+        except:
+            beauty_string('Can not compile the model','block',self.verbose)
+
         beauty_string('Setting the model','block',self.verbose)
         beauty_string(model,'',self.verbose)
 
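The `torch.compile` call is guarded so that PyTorch builds without the compiler (anything before 2.0, or unsupported platforms) simply keep the eager model; compilation is treated as an optimization, not a requirement. A minimal sketch of the same guard outside dsipts:

```python
import torch

model = torch.nn.Linear(8, 1)
try:
    model = torch.compile(model)  # available from PyTorch 2.0 onwards
except Exception:
    pass  # fall back to the uncompiled (eager) model
```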
dsipts/models/ITransformer.py CHANGED
@@ -8,6 +8,8 @@ import numpy as np
 from .itransformer.Transformer_EncDec import Encoder, EncoderLayer
 from .itransformer.SelfAttention_Family import FullAttention, AttentionLayer
 from .itransformer.Embed import DataEmbedding_inverted
+from ..data_structure.utils import beauty_string
+from .utils import get_scope,get_activation,Embedding_cat_variables
 
 try:
     import lightning.pytorch as pl
@@ -17,12 +19,6 @@ except:
     import pytorch_lightning as pl
     OLD_PL = True
 from .base import Base
-from .utils import QuantileLossMO,Permute, get_activation
-
-from typing import List, Union
-from ..data_structure.utils import beauty_string
-from .utils import get_scope
-from .utils import Embedding_cat_variables
 
 
 
@@ -34,8 +30,6 @@ class ITransformer(Base):
     description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
 
     def __init__(self,
-
-
                  # specific params
                  hidden_size:int,
                  d_model: int,
dsipts/models/TTM.py CHANGED
@@ -12,240 +12,133 @@ except:
 from .base import Base
 
 
-
-from typing import List,Union
-
-from .utils import QuantileLossMO
+from .ttm.utils import get_model, get_frequency_token, count_parameters
 from ..data_structure.utils import beauty_string
-from .
-
+from .utils import get_scope
 
 class TTM(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
     def __init__(self,
                  model_path:str,
-                 past_steps:int,
-                 future_steps:int,
-                 freq_prefix_tuning:bool,
-                 freq:str,
                  prefer_l1_loss:bool, # exog: set true to use l1 loss
                  prefer_longer_context:bool,
-                 loss_type:str,
-                 num_input_channels,
                  prediction_channel_indices,
-
+                 exogenous_channel_indices_cont,
+                 exogenous_channel_indices_cat,
                  decoder_mode,
+                 freq,
+                 freq_prefix_tuning,
                  fcm_context_length,
                  fcm_use_mixer,
                  fcm_mix_layers,
                  fcm_prepend_past,
                  enable_forecast_channel_mixing,
-                 out_channels:int,
-                 embs:List[int],
-                 remove_last = False,
-                 optim:Union[str,None]=None,
-                 optim_config:dict=None,
-                 scheduler_config:dict=None,
-                 verbose = False,
-                 use_quantiles=False,
-                 persistence_weight:float=0.0,
-                 quantiles:List[int]=[],
                  **kwargs)->None:
-
-
-        Args:
-            model_path (str): _description_
-            past_steps (int): _description_
-            future_steps (int): _description_
-            freq_prefix_tuning (bool): _description_
-            freq (str): _description_
-            prefer_l1_loss (bool): _description_
-            loss_type (str): _description_
-            num_input_channels (_type_): _description_
-            prediction_channel_indices (_type_): _description_
-            exogenous_channel_indices (_type_): _description_
-            decoder_mode (_type_): _description_
-            fcm_context_length (_type_): _description_
-            fcm_use_mixer (_type_): _description_
-            fcm_mix_layers (_type_): _description_
-            fcm_prepend_past (_type_): _description_
-            enable_forecast_channel_mixing (_type_): _description_
-            out_channels (int): _description_
-            embs (List[int]): _description_
-            remove_last (bool, optional): _description_. Defaults to False.
-            optim (Union[str,None], optional): _description_. Defaults to None.
-            optim_config (dict, optional): _description_. Defaults to None.
-            scheduler_config (dict, optional): _description_. Defaults to None.
-            verbose (bool, optional): _description_. Defaults to False.
-            use_quantiles (bool, optional): _description_. Defaults to False.
-            persistence_weight (float, optional): _description_. Defaults to 0.0.
-            quantiles (List[int], optional): _description_. Defaults to [].
-        """
-        super(TTM, self).__init__(verbose)
+
+        super().__init__(**kwargs)
         self.save_hyperparameters(logger=False)
-
-
-        self.
-
-
-
-        self.loss_type = loss_type
-        self.remove_last = remove_last
-        self.embs = embs
-        self.freq = freq
-        self.extend_variables = False
-
-        # NOTE: For Hydra
-        prediction_channel_indices = list(prediction_channel_indices)
-        exogenous_channel_indices = list(exogenous_channel_indices)
-
-        if len(quantiles)>0:
-            assert len(quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
-            self.use_quantiles = True
-            self.mul = len(quantiles)
-            self.loss = QuantileLossMO(quantiles)
-            self.extend_variables = True
-            if out_channels * 3 != len(prediction_channel_indices):
-                prediction_channel_indices, exogenous_channel_indices, num_input_channels = self.__add_quantile_features(prediction_channel_indices,
-                                                                                                                         exogenous_channel_indices,
-                                                                                                                         out_channels)
+
+
+        self.index_fut = list(exogenous_channel_indices_cont)
+
+        if len(exogenous_channel_indices_cat)>0:
+            self.index_fut_cat = (self.past_channels+len(self.embs_past))+list(exogenous_channel_indices_cat)
         else:
-            self.
-
-
-        elif self.loss_type == 'rmse':
-            self.loss = RMSELoss()
-        else:
-            self.loss = nn.L1Loss()
-
+            self.index_fut_cat = []
+        self.freq = freq
+
         self.model = get_model(
             model_path=model_path,
-            context_length=past_steps,
-            prediction_length=future_steps,
-            freq_prefix_tuning=freq_prefix_tuning,
-            freq=freq,
+            context_length=self.past_steps,
+            prediction_length=self.future_steps,
             prefer_l1_loss=prefer_l1_loss,
             prefer_longer_context=prefer_longer_context,
-            num_input_channels=
+            num_input_channels=self.past_channels+len(self.embs_past), #giusto
             decoder_mode=decoder_mode,
             prediction_channel_indices=list(prediction_channel_indices),
-            exogenous_channel_indices=
+            exogenous_channel_indices=self.index_fut + self.index_fut_cat,
             fcm_context_length=fcm_context_length,
             fcm_use_mixer=fcm_use_mixer,
             fcm_mix_layers=fcm_mix_layers,
+            freq=freq,
+            freq_prefix_tuning=freq_prefix_tuning,
             fcm_prepend_past=fcm_prepend_past,
-            #loss='mse',
             enable_forecast_channel_mixing=enable_forecast_channel_mixing,
+
         )
-        self.
-
-
-        prediction_channel_indices = list(range(out_channels * 3))
-        exogenous_channel_indices = [prediction_channel_indices[-1] + i for i in range(1, len(exogenous_channel_indices)+1)]
-        num_input_channels = len(prediction_channel_indices) + len(exogenous_channel_indices)
-        return prediction_channel_indices, exogenous_channel_indices, num_input_channels
+        hidden_size = self.model.config.hidden_size
+        self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
+        self._freeze_backbone()
 
-    def
+    def _freeze_backbone(self):
         """
        Freeze the backbone of the model.
        This is useful when you want to fine-tune only the head of the model.
        """
-
-
-            count_parameters(self.model),
-        )
+        beauty_string(f"Number of params before freezing backbone:{count_parameters(self.model)}",'info',self.verbose)
+
         # Freeze the backbone of the model
         for param in self.model.backbone.parameters():
             param.requires_grad = False
         # Count params
-
-
-            count_parameters(self.model),
-        )
+        beauty_string(f"Number of params after freezing the backbone: {count_parameters(self.model)}",'info',self.verbose)
+
 
-    def
-
-
+    def _scaler_past(self, input):
+        for i, e in enumerate(self.embs_past):
+            input[:,:,i] = input[:, :, i] / (e-1)
+        return input
+    def _scaler_fut(self, input):
+        for i, e in enumerate(self.embs_fut):
             input[:,:,i] = input[:, :, i] / (e-1)
         return input
-
-    def __build_tupla_indexes(self, size, target_idx, current_idx):
-        permute = list(range(size))
-        history = dict()
-        for j, i in enumerate(target_idx):
-            c = history.get(current_idx[j], current_idx[j])
-            permute[i], permute[c] = current_idx[j], i
-            history[i] = current_idx[j]
-
-
-    def __permute_indexes(self, values, target_idx, current_idx):
-        if current_idx is None or target_idx is None:
-            raise ValueError("Indexes cannot be None")
-        if sorted(current_idx) != sorted(target_idx):
-            return values[..., self.__build_tupla_indexes(values.shape[-1], target_idx, current_idx)]
-        return values
-
-    def __extend_with_quantile_variables(self, x, original_indexes):
-        covariate_indexes = [i for i in range(x.shape[-1]) if i not in original_indexes]
-        covariate_tensors = x[..., covariate_indexes]
 
-        new_tensors = [x[..., target_index] for target_index in original_indexes for _ in range(3)]
 
-        new_original_indexes = list(range(len(original_indexes) * 3))
-        return torch.cat([torch.stack(new_tensors, dim=-1), covariate_tensors], dim=-1), new_original_indexes
-
     def forward(self, batch):
         x_enc = batch['x_num_past']
         original_indexes = batch['idx_target'][0].tolist()
-        original_indexes_future = batch['idx_target_future'][0].tolist()
 
 
-
-        x_enc, original_indexes = self.__extend_with_quantile_variables(x_enc, original_indexes)
+
 
         if 'x_cat_past' in batch.keys():
             x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
-            x_mark_enc = self.
+            x_mark_enc = self._scaler_past(x_mark_enc)
             past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
         else:
             past_values = x_enc
 
-
+        future_values = torch.zeros_like(past_values)
+        future_values = future_values[:,:self.future_steps,:]
+
         if 'x_num_future' in batch.keys():
-
-            if self.extend_variables:
-                x_dec, original_indexes_future = self.__extend_with_quantile_variables(x_dec, original_indexes_future)
+            future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
         if 'x_cat_future' in batch.keys():
             x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
-            x_mark_dec = self.
-            future_values =
-        else:
-            future_values = x_dec
-
-        if self.remove_last:
-            idx_target = batch['idx_target'][0]
-            x_start = x_enc[:,-1,idx_target].unsqueeze(1)
-            x_enc[:,:,idx_target]-=x_start
-
-
-        past_values = self.__permute_indexes(past_values, self.model.prediction_channel_indices, original_indexes)
-
+            x_mark_dec = self._scaler_fut(x_mark_dec)
+            future_values[:,:,self.index_cat_fut] = x_mark_dec
 
-        future_values = self.__permute_indexes(future_values, self.model.prediction_channel_indices, original_indexes_future)
 
-
+        #investigating!!
+        freq_token = get_frequency_token(self.freq).repeat(past_values.shape[0])
 
         res = self.model(
             past_values= past_values,
-            future_values= future_values,
+            future_values= future_values,# future_values if future_values.shape[0]>0 else None,
             past_observed_mask = None,
             future_observed_mask = None,
             output_hidden_states = False,
             return_dict = False,
-            freq_token= freq_token,
+            freq_token= freq_token, ##investigating
             static_categorical_values = None
         )
-
-
+
+
         BS = res.shape[0]
         return res.reshape(BS,self.future_steps,-1,self.mul)
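The new `_scaler_past`/`_scaler_fut` helpers rescale integer category codes into [0, 1] by dividing each categorical channel by its cardinality minus one, so categoricals enter the TTM backbone on a scale comparable to the numeric channels. A standalone sketch of that normalization (hypothetical cardinalities, batch x time x channel layout as in `forward`):

```python
import torch

embs = [24, 7]  # e.g. hour-of-day and day-of-week cardinalities
x = torch.tensor([[[23., 6.],
                   [12., 3.]]])        # shape (batch=1, time=2, channels=2)
for i, e in enumerate(embs):
    x[:, :, i] = x[:, :, i] / (e - 1)  # 23/23 -> 1.0, 6/6 -> 1.0, 3/6 -> 0.5
```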
dsipts/models/base.py CHANGED
@@ -111,8 +111,11 @@ class Base(pl.LightningModule):
         self.train_loss_epoch = -100.0
         self.verbose = verbose
         self.name = self.__class__.__name__
-        self.train_epoch_metrics
-        self.validation_epoch_metrics
+        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("validation_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("train_epoch_count", torch.tensor(0))
+        self.register_buffer("validation_epoch_count", torch.tensor(0))
+
 
         self.use_quantiles = True if len(quantiles)>0 else False
         self.quantiles = quantiles
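Replacing the plain attributes with `register_buffer` tensors makes the epoch accumulators part of the module's state: they follow `.to(device)` moves, are saved in the `state_dict`, and stay outside autograd. A small sketch of those buffer semantics on a bare `torch.nn.Module`:

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))

m = M()
m.train_epoch_metrics += 1.5   # accumulates in place, no grad tracking
print(m.state_dict())          # the buffer is checkpointed with the module
m.train_epoch_metrics.zero_()  # in-place reset at epoch end
```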
@@ -295,7 +298,8 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         loss = self.compute_loss(batch,y_hat)
 
-        self.train_epoch_metrics
+        self.train_epoch_metrics+=loss.detach()
+        self.train_epoch_count +=1
         return loss
 
 
@@ -311,27 +315,20 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         score = 0
         if batch_idx==0:
-
-            idx = 1
-        else:
-            idx = 0
-        #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
-
+
         if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:
-
-
-
-
-
-
-
-
-
-
-
-
-            return self.compute_loss(batch,y_hat)+score
-
+            self._val_outputs.append({
+                "y": batch['y'].detach().cpu(),
+                "y_hat": y_hat.detach().cpu()
+            })
+        self.validation_epoch_metrics = (self.compute_loss(batch,y_hat)+score).detach()
+        self.validation_epoch_count+=1
+
+        return None #self.compute_loss(batch,y_hat)+score
+
+    def on_validation_start(self):
+        # reset buffer each epoch
+        self._val_outputs = []
 
     def validation_epoch_end(self, outs):
         """
@@ -339,14 +336,30 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
-        if len(
-
-
-
-
-
-
+        if len(self._val_outputs)>0:
+            ys = torch.cat([o["y"] for o in self._val_outputs])
+            y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
+            if self.use_quantiles:
+                idx = 1
+            else:
+                idx = 0
+            for i in range(ys.shape[2]):
+                real = ys[0,:,i].cpu().detach().numpy()
+                pred = y_hats[0,:,i,idx].cpu().detach().numpy()
+                fig, ax = plt.subplots(figsize=(7,5))
+                ax.plot(real,'o-',label='real')
+                ax.plot(pred,'o-',label='pred')
+                ax.legend()
+                ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                self.logger.experiment.track(Image(fig), name='cm_training_end')
+                #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+                plt.close(fig)
+        avg = self.validation_epoch_metrics/self.validation_epoch_count
+
+        self.validation_epoch_metrics.zero_()
+        self.validation_epoch_count.zero_()
+        self.log("val_loss", avg,sync_dist=True)
+        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
 
     def training_epoch_end(self, outs):
         """
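`validation_epoch_end` now derives the epoch loss from the sum/count buffers and zeroes them for the next epoch, rather than averaging the `outs` list. The core pattern, reduced to a sketch (accumulating with `+=` as the `base_v2` variant below does):

```python
import torch

metric = torch.tensor(0.0)  # loss accumulator (a buffer in the real module)
count = torch.tensor(0)     # batches seen this epoch

for batch_loss in (torch.tensor(0.5), torch.tensor(0.3)):  # one entry per step
    metric += batch_loss.detach()
    count += 1

avg = metric / count  # logged as val_loss
metric.zero_()        # in-place reset for the next epoch
count.zero_()
```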
@@ -355,12 +368,11 @@ class Base(pl.LightningModule):
         :meta private:
         """
 
-        loss =
-        self.log("train_loss", loss
+        loss = self.train_epoch_metrics/self.global_step
+        self.log("train_loss", loss,sync_dist=True)
         self.count_epoch+=1
 
-        self.train_loss_epoch = loss
-
+        self.train_loss_epoch = loss
     def compute_loss(self,batch,y_hat):
         """
         custom loss calculation
dsipts/models/base_v2.py CHANGED
@@ -15,7 +15,6 @@ from typing import List, Union
 from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
 
-
 def standardize_momentum(x,order):
     mean = torch.mean(x,1).unsqueeze(1).repeat(1,x.shape[1],1)
     num = torch.pow(x-mean,order).mean(axis=1)
@@ -113,8 +112,13 @@ class Base(pl.LightningModule):
         self.train_loss_epoch = -100.0
         self.verbose = verbose
         self.name = self.__class__.__name__
-        self.train_epoch_metrics =
-        self.validation_epoch_metrics =
+        #self.train_epoch_metrics = 0
+        #self.validation_epoch_metrics = 0
+
+        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("validation_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("train_epoch_count", torch.tensor(0))
+        self.register_buffer("validation_epoch_count", torch.tensor(0))
 
         self.use_quantiles = True if len(quantiles)>0 else False
         self.quantiles = quantiles
@@ -299,7 +303,8 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         loss = self.compute_loss(batch,y_hat)
 
-        self.train_epoch_metrics
+        self.train_epoch_metrics+=loss.detach()
+        self.train_epoch_count +=1
         return loss
 
 
@@ -316,41 +321,54 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         score = 0
         if batch_idx==0:
-
-            idx = 1
-        else:
-            idx = 0
+
         #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
 
         if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:
+            self._val_outputs.append({
+                "y": batch['y'].detach().cpu(),
+                "y_hat": y_hat.detach().cpu()
+            })
+        self.validation_epoch_metrics+= (self.compute_loss(batch,y_hat)+score).detach()
+        self.validation_epoch_count+=1
+        return None
 
-
-
-
-
-            ax.plot(real,'o-',label='real')
-            ax.plot(pred,'o-',label='pred')
-            ax.legend()
-            ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
-            self.logger.experiment.track(Image(fig), name='cm_training_end')
-            #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
-        self.validation_epoch_metrics.append(self.compute_loss(batch,y_hat)+score)
-        return
-
+    def on_validation_start(self):
+        # reset buffer each epoch
+        self._val_outputs = []
+
 
     def on_validation_epoch_end(self):
         """
         pythotrch lightening stuff
 
         :meta private:
-        """
+        """
 
-        if len(self.
-
-
-
-
-
+        if len(self._val_outputs)>0:
+            ys = torch.cat([o["y"] for o in self._val_outputs])
+            y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
+            if self.use_quantiles:
+                idx = 1
+            else:
+                idx = 0
+            for i in range(ys.shape[2]):
+                real = ys[0,:,i].cpu().detach().numpy()
+                pred = y_hats[0,:,i,idx].cpu().detach().numpy()
+                fig, ax = plt.subplots(figsize=(7,5))
+                ax.plot(real,'o-',label='real')
+                ax.plot(pred,'o-',label='pred')
+                ax.legend()
+                ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                self.logger.experiment.track(Image(fig), name='cm_training_end')
+                #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+                plt.close(fig)
+
+
+        avg = self.validation_epoch_metrics/self.validation_epoch_count
+
+        self.validation_epoch_metrics.zero_()
+        self.validation_epoch_count.zero_()
         self.log("val_loss", avg,sync_dist=True)
         beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
 
@@ -361,14 +379,12 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
-
-
-            beauty_string(f'THIS IS A BUG, It should be polulated','info',self.verbose)
-        else:
-            avg = np.stack(self.train_epoch_metrics).mean()
+
+        avg = self.train_epoch_metrics/self.train_epoch_count
         self.log("train_loss", avg,sync_dist=True)
         self.count_epoch+=1
-        self.train_epoch_metrics
+        self.train_epoch_metrics.zero_()
+        self.train_epoch_count.zero_()
         self.train_loss_epoch = avg
 
     def compute_loss(self,batch,y_hat):
{dsipts-1.1.10.dist-info → dsipts-1.1.11.dist-info}/RECORD CHANGED

@@ -3,7 +3,7 @@ dsipts/data_management/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
 dsipts/data_management/monash.py,sha256=aZxq9FbIH6IsU8Lwou1hAokXjgOAK-wdl2VAeFg2k4M,13075
 dsipts/data_management/public_datasets.py,sha256=yXFzOZZ-X0ZG1DoqVU-zFmEGVMc2033YDQhRgYxY8ws,6793
 dsipts/data_structure/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dsipts/data_structure/data_structure.py,sha256=
+dsipts/data_structure/data_structure.py,sha256=uyGkc1eDjETpXb8rgMMbRUjG8i9Xiiu6vZc64xfTiew,59914
 dsipts/data_structure/modifiers.py,sha256=qlry9dfw8pEE0GrvgwROZJkJ6oPpUnjEHPIG5qIetss,7948
 dsipts/data_structure/utils.py,sha256=QwfKPZgSy6DIw5n6ztOdPJIAnzo4EnlMTgRbpiWnyko,6593
 dsipts/models/Autoformer.py,sha256=ddGT3L9T4gAXNJHx1TsuYZy7j63Anyr0rkqqXaOoSu4,8447
@@ -13,7 +13,7 @@ dsipts/models/Diffusion.py,sha256=pUujnrdeSSkj4jC1RORbcptt03KpuCsGVwg414o4LPg,40
 dsipts/models/DilatedConv.py,sha256=_c0NvFuT3vbYmo9A8cQchGo1XVb0qOpzBprNEkkAgiE,14292
 dsipts/models/DilatedConvED.py,sha256=fXk1-EWiRC5J_VIepTjYKya_D02SlEAkyiJcCjhW_XU,14004
 dsipts/models/Duet.py,sha256=EharWHT_r7tEYIk7BkozVLPZ0xptE5mmQmeFGm3uBsA,7628
-dsipts/models/ITransformer.py,sha256=
+dsipts/models/ITransformer.py,sha256=qMsk27PqpnakNY1YM_rbkj8MO6BaG06N3b6m30Oa0RQ,7256
 dsipts/models/Informer.py,sha256=ByJ00qGk12ONFF7NZWAACzxxRb5UXcu5wpkGMYX9Cq4,6920
 dsipts/models/LinearTS.py,sha256=B0-Sz4POwUyl-PN2ssSx8L-ZHgwrQQPcMmreyvSS47U,9104
 dsipts/models/PatchTST.py,sha256=Z7DM1Kw5Ym8Hh9ywj0j9RuFtKaz_yVZmKFIYafjceM8,9061
@@ -23,13 +23,13 @@ dsipts/models/Samformer.py,sha256=s61Hi1o9iuw-KgSBPfiE80oJcK1j2fUA6N9f5BJgKJc,55
 dsipts/models/Simple.py,sha256=K82E88A62NhV_7U9Euu2cn3Q8P287HDR7eIy7VqgwbM,3909
 dsipts/models/TFT.py,sha256=JO2-AKIUag7bfm9Oeo4KmGfdYZJbzQBHPDqGVg0WUZI,13830
 dsipts/models/TIDE.py,sha256=i8qXac2gImEVgE2X6cNxqW5kuQP3rzWMlQNdgJbNmKM,13033
-dsipts/models/TTM.py,sha256=
+dsipts/models/TTM.py,sha256=gc-8yzEtn8ZdRVvsZfZvz7iE-RgqpZc-JGmOCQr4U_0,5215
 dsipts/models/TimeXER.py,sha256=aCg0003LxYZzqZWyWugpbW_iOybcdHN4OH6_v77qp4o,7056
 dsipts/models/VQVAEA.py,sha256=sNJi8UZh-10qEIKcZK3SzhlOFUUjvqjoglzeZBFaeZM,13789
 dsipts/models/VVA.py,sha256=BnPkJ0Nzue0oShSHZVRNlf5RvT0Iwtf9bx19vLB9Nn0,11939
 dsipts/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dsipts/models/base.py,sha256=
-dsipts/models/base_v2.py,sha256=
+dsipts/models/base.py,sha256=Gqsycy8ZXGaIVx9vvmYRpBCqdUxGE4tvC5ltgxlpEYY,19640
+dsipts/models/base_v2.py,sha256=eraXo1IBEQmyW41f1dz3Q-i-61vZ2AS3tVz6_X8J0Pg,19886
 dsipts/models/utils.py,sha256=kjTwyktNCFMpPUy6zoleBCSKlvMvK_Jkgyh2T1OXg3E,24497
 dsipts/models/autoformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dsipts/models/autoformer/layers.py,sha256=xHt8V1lKdD1cIvgxXdDbI_EqOz4zgOQ6LP8l7M1pAxM,13276
@@ -76,7 +76,7 @@ dsipts/models/vva/minigpt.py,sha256=bg0JddqSD322uxSGexen3nPXL_hGTsk3vNLR62d7-w8,
 dsipts/models/vva/vqvae.py,sha256=RzCQ_M9xBprp7_x20dSV3EQqlO0FjPUGWV-qdyKrQsM,19680
 dsipts/models/xlstm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dsipts/models/xlstm/xLSTM.py,sha256=ZKZZmffmIq1Vb71CR4GSyM8viqVx-u0FChxhcNgHub8,10081
-dsipts-1.1.
-dsipts-1.1.
-dsipts-1.1.
-dsipts-1.1.
+dsipts-1.1.11.dist-info/METADATA,sha256=fbMTKqi7b_vlvtmVSp5XJdkFrEC9SFF3DG_fKy58k_8,24795
+dsipts-1.1.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dsipts-1.1.11.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
+dsipts-1.1.11.dist-info/RECORD,,

{dsipts-1.1.10.dist-info → dsipts-1.1.11.dist-info}/WHEEL: file without changes

{dsipts-1.1.10.dist-info → dsipts-1.1.11.dist-info}/top_level.txt: file without changes