dsipts 1.1.10-py3-none-any.whl → 1.1.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dsipts/data_structure/data_structure.py CHANGED
@@ -35,7 +35,18 @@ from .modifiers import *
 from aim.pytorch_lightning import AimLogger
 import time
 
-
+class DummyScaler():
+    def __init__(self):
+        pass
+    def fit(self,x):
+        pass
+    def transform(self,x):
+        return x
+    def inverse_transform(self,x):
+        return x
+    def fit_transform(self,x):
+        return x
+
 
 pd.options.mode.chained_assignment = None
 log = logging.getLogger(__name__)
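
The new DummyScaler gives dsipts a no-op stand-in with the sklearn scaler interface (fit, transform, inverse_transform, fit_transform), useful when a pipeline expects a scaler but values should pass through unchanged. A minimal standalone sketch of how such a class behaves (re-declared here for illustration, since the diff does not show its import path):

    import numpy as np

    class DummyScaler():
        def fit(self, x):
            pass                                    # nothing to learn
        def transform(self, x):
            return x                                # identity
        def inverse_transform(self, x):
            return x
        def fit_transform(self, x):
            return x

    scaler = DummyScaler()                          # drop-in where e.g. StandardScaler() would go
    x = np.array([[1.0], [2.0], [3.0]])
    scaler.fit(x)
    assert (scaler.transform(x) == x).all()
    assert (scaler.inverse_transform(x) == x).all()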
@@ -210,20 +221,23 @@ class TimeSeries():
         self.future_variables = []
         self.target_variables = ['signal']
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
-
+        self.num_var = list(np.sort(self.num_var))
 
     def enrich(self,dataset,columns):
-        if columns =='hour':
-            dataset[columns] = dataset.time.dt.hour
-        elif columns=='dow':
-            dataset[columns] = dataset.time.dt.weekday
-        elif columns=='month':
-            dataset[columns] = dataset.time.dt.month
-        elif columns=='minute':
-            dataset[columns] = dataset.time.dt.minute
-        else:
-            if columns not in dataset.columns:
-                beauty_string(f'I can not automatically enrich column {columns}. Please contact the developers or add it manually to your dataset.','section',True)
+        try:
+            if columns =='hour':
+                dataset[columns] = dataset.time.dt.hour
+            elif columns=='dow':
+                dataset[columns] = dataset.time.dt.weekday
+            elif columns=='month':
+                dataset[columns] = dataset.time.dt.month
+            elif columns=='minute':
+                dataset[columns] = dataset.time.dt.minute
+            else:
+                if columns not in dataset.columns:
+                    beauty_string(f'I can not automatically enrich column {columns}. Please contact the developers or add it manually to your dataset.','section',True)
+        except:
+            beauty_string(f'I can not automatically enrich column {columns}. Probably not a temporal index.','section',True)
 
     def load_signal(self,data:pd.DataFrame,
                     enrich_cat:List[str] = [],
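
The try/except added to enrich guards the pandas .dt accessor, which raises when the time column is not datetime-like; that is the "probably not a temporal index" case. A standalone sketch of both paths (DataFrames invented for illustration):

    import pandas as pd

    df = pd.DataFrame({'time': pd.date_range('2024-01-01', periods=4, freq='h')})
    df['hour'] = df.time.dt.hour                # works: 'time' is datetime64
    df['dow'] = df.time.dt.weekday              # Monday=0 ... Sunday=6

    bad = pd.DataFrame({'time': [1, 2, 3]})     # integer column, no .dt accessor
    try:
        bad['hour'] = bad.time.dt.hour
    except AttributeError:
        print('probably not a temporal index')  # what the new except branch reports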
@@ -300,7 +314,7 @@ class TimeSeries():
         if check_past:
             beauty_string('I will update past column adding all target columns, if you want to avoid this beahviour please use check_pass as false','info',self.verbose)
             past_variables = list(set(past_variables).union(set(target_variables)))
-
+            past_variables = list(np.sort(past_variables))
         self.cat_past_var = cat_past_var
         self.cat_fut_var = cat_fut_var
 
@@ -321,14 +335,18 @@ class TimeSeries():
                 beauty_string('Categorical {c} already present, it will be added to categorical variable but not call the enriching function','info',self.verbose)
             else:
                 self.enrich(dataset,c)
+        self.cat_past_var = list(np.sort(self.cat_past_var))
+        self.cat_fut_var = list(np.sort(self.cat_fut_var))
+
         self.cat_var = list(set(self.cat_past_var+self.cat_fut_var)) ## all categorical data
-
+        self.cat_var = list(np.sort(self.cat_var))
         self.dataset = dataset
         self.past_variables = past_variables
         self.future_variables = future_variables
         self.target_variables = target_variables
         self.out_vars = len(target_variables)
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
+        self.num_var = list(np.sort(self.num_var))
         if silly_model:
             beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
             self.future_variables+=self.target_variables
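
The repeated list(np.sort(...)) calls all address the same issue: a Python set union has no guaranteed iteration order, so without sorting, the column (and hence model channel) layout could differ between runs of the same configuration. A minimal illustration with invented column names:

    import numpy as np

    past = ['temp', 'wind']
    target = ['signal']
    num_var = list(set(past).union(set(target)))  # order can vary between interpreter runs
    num_var = list(np.sort(num_var))              # deterministic: ['signal', 'temp', 'wind']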
@@ -665,7 +683,11 @@ class TimeSeries():
         #self.model.apply(weight_init_zeros)
 
         self.config = config
-
+        try:
+            self.model = torch.compile(self.model)
+        except:
+            beauty_string('Can not compile the model','block',self.verbose)
+
         beauty_string('Setting the model','block',self.verbose)
         beauty_string(model,'',self.verbose)
 
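torch.compile exists only from PyTorch 2.0 and can also fail on unsupported platforms or backends, which is presumably why the call is wrapped defensively. The same guard pattern in isolation (toy module; the broad except mirrors the diff's behavior):

    import torch

    model = torch.nn.Linear(8, 1)
    try:
        model = torch.compile(model)        # AttributeError on torch < 2.0, backend errors elsewhere
    except Exception:
        print('Can not compile the model')  # keep the eager model as a fallback
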
dsipts/models/ITransformer.py CHANGED
@@ -8,6 +8,8 @@ import numpy as np
 from .itransformer.Transformer_EncDec import Encoder, EncoderLayer
 from .itransformer.SelfAttention_Family import FullAttention, AttentionLayer
 from .itransformer.Embed import DataEmbedding_inverted
+from ..data_structure.utils import beauty_string
+from .utils import get_scope,get_activation,Embedding_cat_variables
 
 try:
     import lightning.pytorch as pl
@@ -17,12 +19,6 @@ except:
     import pytorch_lightning as pl
     OLD_PL = True
 from .base import Base
-from .utils import QuantileLossMO,Permute, get_activation
-
-from typing import List, Union
-from ..data_structure.utils import beauty_string
-from .utils import get_scope
-from .utils import Embedding_cat_variables
 
 
 
@@ -34,8 +30,6 @@ class ITransformer(Base):
     description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
 
     def __init__(self,
-
-
                  # specific params
                  hidden_size:int,
                  d_model: int,
dsipts/models/TTM.py CHANGED
@@ -12,240 +12,133 @@ except:
 from .base import Base
 
 
-
-from typing import List,Union
-
-from .utils import QuantileLossMO
+from .ttm.utils import get_model, get_frequency_token, count_parameters
 from ..data_structure.utils import beauty_string
-from .ttm.utils import get_model, get_frequency_token, count_parameters, RMSELoss
-
+from .utils import get_scope
 
 class TTM(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
     def __init__(self,
                  model_path:str,
-                 past_steps:int,
-                 future_steps:int,
-                 freq_prefix_tuning:bool,
-                 freq:str,
                  prefer_l1_loss:bool, # exog: set true to use l1 loss
                  prefer_longer_context:bool,
-                 loss_type:str,
-                 num_input_channels,
                  prediction_channel_indices,
-                 exogenous_channel_indices,
+                 exogenous_channel_indices_cont,
+                 exogenous_channel_indices_cat,
                  decoder_mode,
+                 freq,
+                 freq_prefix_tuning,
                  fcm_context_length,
                  fcm_use_mixer,
                  fcm_mix_layers,
                  fcm_prepend_past,
                  enable_forecast_channel_mixing,
-                 out_channels:int,
-                 embs:List[int],
-                 remove_last = False,
-                 optim:Union[str,None]=None,
-                 optim_config:dict=None,
-                 scheduler_config:dict=None,
-                 verbose = False,
-                 use_quantiles=False,
-                 persistence_weight:float=0.0,
-                 quantiles:List[int]=[],
                  **kwargs)->None:
-        """TODO and FIX for future and past categorical variables
-
-        Args:
-            model_path (str): _description_
-            past_steps (int): _description_
-            future_steps (int): _description_
-            freq_prefix_tuning (bool): _description_
-            freq (str): _description_
-            prefer_l1_loss (bool): _description_
-            loss_type (str): _description_
-            num_input_channels (_type_): _description_
-            prediction_channel_indices (_type_): _description_
-            exogenous_channel_indices (_type_): _description_
-            decoder_mode (_type_): _description_
-            fcm_context_length (_type_): _description_
-            fcm_use_mixer (_type_): _description_
-            fcm_mix_layers (_type_): _description_
-            fcm_prepend_past (_type_): _description_
-            enable_forecast_channel_mixing (_type_): _description_
-            out_channels (int): _description_
-            embs (List[int]): _description_
-            remove_last (bool, optional): _description_. Defaults to False.
-            optim (Union[str,None], optional): _description_. Defaults to None.
-            optim_config (dict, optional): _description_. Defaults to None.
-            scheduler_config (dict, optional): _description_. Defaults to None.
-            verbose (bool, optional): _description_. Defaults to False.
-            use_quantiles (bool, optional): _description_. Defaults to False.
-            persistence_weight (float, optional): _description_. Defaults to 0.0.
-            quantiles (List[int], optional): _description_. Defaults to [].
-        """
-        super(TTM, self).__init__(verbose)
+
+        super().__init__(**kwargs)
         self.save_hyperparameters(logger=False)
-        self.future_steps = future_steps
-        self.use_quantiles = use_quantiles
-        self.optim = optim
-        self.optim_config = optim_config
-        self.scheduler_config = scheduler_config
-        self.persistence_weight = persistence_weight
-        self.loss_type = loss_type
-        self.remove_last = remove_last
-        self.embs = embs
-        self.freq = freq
-        self.extend_variables = False
-
-        # NOTE: For Hydra
-        prediction_channel_indices = list(prediction_channel_indices)
-        exogenous_channel_indices = list(exogenous_channel_indices)
-
-        if len(quantiles)>0:
-            assert len(quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
-            self.use_quantiles = True
-            self.mul = len(quantiles)
-            self.loss = QuantileLossMO(quantiles)
-            self.extend_variables = True
-            if out_channels * 3 != len(prediction_channel_indices):
-                prediction_channel_indices, exogenous_channel_indices, num_input_channels = self.__add_quantile_features(prediction_channel_indices,
-                                                                                                                         exogenous_channel_indices,
-                                                                                                                         out_channels)
+
+
+        self.index_fut = list(exogenous_channel_indices_cont)
+
+        if len(exogenous_channel_indices_cat)>0:
+            self.index_fut_cat = (self.past_channels+len(self.embs_past))+list(exogenous_channel_indices_cat)
         else:
-            self.mul = 1
-            if self.loss_type == 'mse':
-                self.loss = nn.MSELoss(reduction="mean")
-            elif self.loss_type == 'rmse':
-                self.loss = RMSELoss()
-            else:
-                self.loss = nn.L1Loss()
-
+            self.index_fut_cat = []
+        self.freq = freq
+
         self.model = get_model(
             model_path=model_path,
-            context_length=past_steps,
-            prediction_length=future_steps,
-            freq_prefix_tuning=freq_prefix_tuning,
-            freq=freq,
+            context_length=self.past_steps,
+            prediction_length=self.future_steps,
             prefer_l1_loss=prefer_l1_loss,
             prefer_longer_context=prefer_longer_context,
-            num_input_channels=num_input_channels,
+            num_input_channels=self.past_channels+len(self.embs_past), # correct
            decoder_mode=decoder_mode,
             prediction_channel_indices=list(prediction_channel_indices),
-            exogenous_channel_indices=list(exogenous_channel_indices),
+            exogenous_channel_indices=self.index_fut + self.index_fut_cat,
             fcm_context_length=fcm_context_length,
             fcm_use_mixer=fcm_use_mixer,
             fcm_mix_layers=fcm_mix_layers,
+            freq=freq,
+            freq_prefix_tuning=freq_prefix_tuning,
             fcm_prepend_past=fcm_prepend_past,
-            #loss='mse',
             enable_forecast_channel_mixing=enable_forecast_channel_mixing,
+
         )
-        self.__freeze_backbone()
-
-    def __add_quantile_features(self, prediction_channel_indices, exogenous_channel_indices, out_channels):
-        prediction_channel_indices = list(range(out_channels * 3))
-        exogenous_channel_indices = [prediction_channel_indices[-1] + i for i in range(1, len(exogenous_channel_indices)+1)]
-        num_input_channels = len(prediction_channel_indices) + len(exogenous_channel_indices)
-        return prediction_channel_indices, exogenous_channel_indices, num_input_channels
+        hidden_size = self.model.config.hidden_size
+        self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
+        self._freeze_backbone()
 
-    def __freeze_backbone(self):
+    def _freeze_backbone(self):
         """
         Freeze the backbone of the model.
         This is useful when you want to fine-tune only the head of the model.
         """
-        print(
-            "Number of params before freezing backbone",
-            count_parameters(self.model),
-        )
+        beauty_string(f"Number of params before freezing backbone:{count_parameters(self.model)}",'info',self.verbose)
+
         # Freeze the backbone of the model
         for param in self.model.backbone.parameters():
             param.requires_grad = False
         # Count params
-        print(
-            "Number of params after freezing the backbone",
-            count_parameters(self.model),
-        )
+        beauty_string(f"Number of params after freezing the backbone: {count_parameters(self.model)}",'info',self.verbose)
+
 
-    def __scaler(self, input):
-        #new_data = torch.tensor([MinMaxScaler().fit_transform(step_data) for step_data in data])
-        for i, e in enumerate(self.embs):
+    def _scaler_past(self, input):
+        for i, e in enumerate(self.embs_past):
+            input[:,:,i] = input[:, :, i] / (e-1)
+        return input
+    def _scaler_fut(self, input):
+        for i, e in enumerate(self.embs_fut):
             input[:,:,i] = input[:, :, i] / (e-1)
         return input
-
-    def __build_tupla_indexes(self, size, target_idx, current_idx):
-        permute = list(range(size))
-        history = dict()
-        for j, i in enumerate(target_idx):
-            c = history.get(current_idx[j], current_idx[j])
-            permute[i], permute[c] = current_idx[j], i
-            history[i] = current_idx[j]
-
-
-    def __permute_indexes(self, values, target_idx, current_idx):
-        if current_idx is None or target_idx is None:
-            raise ValueError("Indexes cannot be None")
-        if sorted(current_idx) != sorted(target_idx):
-            return values[..., self.__build_tupla_indexes(values.shape[-1], target_idx, current_idx)]
-        return values
-
-    def __extend_with_quantile_variables(self, x, original_indexes):
-        covariate_indexes = [i for i in range(x.shape[-1]) if i not in original_indexes]
-        covariate_tensors = x[..., covariate_indexes]
 
-        new_tensors = [x[..., target_index] for target_index in original_indexes for _ in range(3)]
 
-        new_original_indexes = list(range(len(original_indexes) * 3))
-        return torch.cat([torch.stack(new_tensors, dim=-1), covariate_tensors], dim=-1), new_original_indexes
-
     def forward(self, batch):
         x_enc = batch['x_num_past']
         original_indexes = batch['idx_target'][0].tolist()
-        original_indexes_future = batch['idx_target_future'][0].tolist()
 
 
-        if self.extend_variables:
-            x_enc, original_indexes = self.__extend_with_quantile_variables(x_enc, original_indexes)
+
 
         if 'x_cat_past' in batch.keys():
             x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
-            x_mark_enc = self.__scaler(x_mark_enc)
+            x_mark_enc = self._scaler_past(x_mark_enc)
             past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
         else:
             past_values = x_enc
 
-        x_dec = torch.tensor([]).to(self.device)
+        future_values = torch.zeros_like(past_values)
+        future_values = future_values[:,:self.future_steps,:]
+
         if 'x_num_future' in batch.keys():
-            x_dec = batch['x_num_future'].to(self.device)
-            if self.extend_variables:
-                x_dec, original_indexes_future = self.__extend_with_quantile_variables(x_dec, original_indexes_future)
+            future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
         if 'x_cat_future' in batch.keys():
             x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
-            x_mark_dec = self.__scaler(x_mark_dec)
-            future_values = torch.cat((x_dec, x_mark_dec), axis=-1).type(torch.float32)
-        else:
-            future_values = x_dec
-
-        if self.remove_last:
-            idx_target = batch['idx_target'][0]
-            x_start = x_enc[:,-1,idx_target].unsqueeze(1)
-            x_enc[:,:,idx_target]-=x_start
-
-
-        past_values = self.__permute_indexes(past_values, self.model.prediction_channel_indices, original_indexes)
-
+            x_mark_dec = self._scaler_fut(x_mark_dec)
+            future_values[:,:,self.index_cat_fut] = x_mark_dec
 
-        future_values = self.__permute_indexes(future_values, self.model.prediction_channel_indices, original_indexes_future)
 
-        freq_token = get_frequency_token(self.freq).repeat(x_enc.shape[0])
+        #investigating!!
+        freq_token = get_frequency_token(self.freq).repeat(past_values.shape[0])
 
         res = self.model(
             past_values= past_values,
-            future_values= future_values,
+            future_values= future_values,# future_values if future_values.shape[0]>0 else None,
             past_observed_mask = None,
             future_observed_mask = None,
             output_hidden_states = False,
             return_dict = False,
-            freq_token= freq_token,
+            freq_token= freq_token, ##investigating
             static_categorical_values = None
         )
-        #args = None
-        #res = self.model(**args)
+
+
         BS = res.shape[0]
         return res.reshape(BS,self.future_steps,-1,self.mul)
 
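The renamed _scaler_past/_scaler_fut helpers rescale integer category codes 0..e-1 into [0, 1] channel by channel before they are concatenated with the numeric channels. The same arithmetic in isolation (batch size, steps, and cardinalities invented):

    import torch

    embs = [24, 7]                          # e.g. hour-of-day and day-of-week cardinalities
    x_cat = torch.stack([
        torch.randint(0, 24, (16, 36)),     # batch=16, steps=36, hour codes
        torch.randint(0, 7,  (16, 36)),     # weekday codes
    ], dim=-1).float()

    for i, e in enumerate(embs):            # mirrors the loop in _scaler_past/_scaler_fut
        x_cat[:, :, i] = x_cat[:, :, i] / (e - 1)
    assert x_cat.min() >= 0.0 and x_cat.max() <= 1.0
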
dsipts/models/base.py CHANGED
@@ -111,8 +111,11 @@ class Base(pl.LightningModule):
         self.train_loss_epoch = -100.0
         self.verbose = verbose
         self.name = self.__class__.__name__
-        self.train_epoch_metrics = []
-        self.validation_epoch_metrics = []
+        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("validation_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("train_epoch_count", torch.tensor(0))
+        self.register_buffer("validation_epoch_count", torch.tensor(0))
+
 
         self.use_quantiles = True if len(quantiles)>0 else False
         self.quantiles = quantiles
@@ -295,7 +298,8 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         loss = self.compute_loss(batch,y_hat)
 
-        self.train_epoch_metrics.append(loss.item())
+        self.train_epoch_metrics+=loss.detach()
+        self.train_epoch_count +=1
         return loss
 
 
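Replacing the Python list of loss.item() values with registered buffers has two effects: the running sum and count are tensors that follow the module across devices, and accumulating loss.detach() drops the autograd graph instead of holding it for the whole epoch. A self-contained sketch of the pattern (module and values invented):

    import torch
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            # buffers move with .to(device) like parameters, unlike plain lists
            self.register_buffer("loss_sum", torch.tensor(0.0))
            self.register_buffer("count", torch.tensor(0))

    m = Demo()
    for loss in [torch.tensor(1.0), torch.tensor(2.0)]:
        m.loss_sum += loss.detach()   # graph-free accumulation
        m.count += 1
    print(m.loss_sum / m.count)       # tensor(1.5000)
    m.loss_sum.zero_()                # in-place reset at epoch end
    m.count.zero_()
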
@@ -311,27 +315,20 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         score = 0
         if batch_idx==0:
-            if self.use_quantiles:
-                idx = 1
-            else:
-                idx = 0
-            #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
-
+
             if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:
-
-                for i in range(batch['y'].shape[2]):
-                    real = batch['y'][0,:,i].cpu().detach().numpy()
-                    pred = y_hat[0,:,i,idx].cpu().detach().numpy()
-                    fig, ax = plt.subplots(figsize=(7,5))
-                    ax.plot(real,'o-',label='real')
-                    ax.plot(pred,'o-',label='pred')
-                    ax.legend()
-                    ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
-                    self.logger.experiment.track(Image(fig), name='cm_training_end')
-                    #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
-
-        return self.compute_loss(batch,y_hat)+score
-
+                self._val_outputs.append({
+                    "y": batch['y'].detach().cpu(),
+                    "y_hat": y_hat.detach().cpu()
+                })
+        self.validation_epoch_metrics = (self.compute_loss(batch,y_hat)+score).detach()
+        self.validation_epoch_count+=1
+
+        return None #self.compute_loss(batch,y_hat)+score
+
+    def on_validation_start(self):
+        # reset buffer each epoch
+        self._val_outputs = []
 
     def validation_epoch_end(self, outs):
         """
@@ -339,14 +336,30 @@
 
         :meta private:
         """
-        if len(outs)==0:
-            loss = 10000
-            beauty_string(f'THIS IS A BUG, It should be polulated','info',self.verbose)
-        else:
-            loss = torch.stack(outs).mean()
-
-        self.log("val_loss", loss.item(),sync_dist=True)
-        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {loss.item():.4f}','info',self.verbose)
+        if len(self._val_outputs)>0:
+            ys = torch.cat([o["y"] for o in self._val_outputs])
+            y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
+            if self.use_quantiles:
+                idx = 1
+            else:
+                idx = 0
+            for i in range(ys.shape[2]):
+                real = ys[0,:,i].cpu().detach().numpy()
+                pred = y_hats[0,:,i,idx].cpu().detach().numpy()
+                fig, ax = plt.subplots(figsize=(7,5))
+                ax.plot(real,'o-',label='real')
+                ax.plot(pred,'o-',label='pred')
+                ax.legend()
+                ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                self.logger.experiment.track(Image(fig), name='cm_training_end')
+                #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+                plt.close(fig)
+        avg = self.validation_epoch_metrics/self.validation_epoch_count
+
+        self.validation_epoch_metrics.zero_()
+        self.validation_epoch_count.zero_()
+        self.log("val_loss", avg,sync_dist=True)
+        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
 
     def training_epoch_end(self, outs):
         """
@@ -355,12 +368,11 @@
         :meta private:
         """
 
-        loss = sum(outs['loss'] for outs in outs) / len(outs)
-        self.log("train_loss", loss.item(),sync_dist=True)
+        loss = self.train_epoch_metrics/self.global_step
+        self.log("train_loss", loss,sync_dist=True)
         self.count_epoch+=1
 
-        self.train_loss_epoch = loss.item()
-
+        self.train_loss_epoch = loss
     def compute_loss(self,batch,y_hat):
         """
         custom loss calculation
dsipts/models/base_v2.py CHANGED
@@ -15,7 +15,6 @@ from typing import List, Union
 from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
 
-
 def standardize_momentum(x,order):
     mean = torch.mean(x,1).unsqueeze(1).repeat(1,x.shape[1],1)
     num = torch.pow(x-mean,order).mean(axis=1)
@@ -113,8 +112,13 @@ class Base(pl.LightningModule):
         self.train_loss_epoch = -100.0
         self.verbose = verbose
         self.name = self.__class__.__name__
-        self.train_epoch_metrics = []
-        self.validation_epoch_metrics = []
+        #self.train_epoch_metrics = 0
+        #self.validation_epoch_metrics = 0
+
+        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("validation_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("train_epoch_count", torch.tensor(0))
+        self.register_buffer("validation_epoch_count", torch.tensor(0))
 
         self.use_quantiles = True if len(quantiles)>0 else False
         self.quantiles = quantiles
@@ -299,7 +303,8 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         loss = self.compute_loss(batch,y_hat)
 
-        self.train_epoch_metrics.append(loss.item())
+        self.train_epoch_metrics+=loss.detach()
+        self.train_epoch_count +=1
         return loss
 
 
@@ -316,41 +321,54 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         score = 0
         if batch_idx==0:
-            if self.use_quantiles:
-                idx = 1
-            else:
-                idx = 0
+
             #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
 
             if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:
+                self._val_outputs.append({
+                    "y": batch['y'].detach().cpu(),
+                    "y_hat": y_hat.detach().cpu()
+                })
+        self.validation_epoch_metrics+= (self.compute_loss(batch,y_hat)+score).detach()
+        self.validation_epoch_count+=1
+        return None
 
-                for i in range(batch['y'].shape[2]):
-                    real = batch['y'][0,:,i].cpu().detach().numpy()
-                    pred = y_hat[0,:,i,idx].cpu().detach().numpy()
-                    fig, ax = plt.subplots(figsize=(7,5))
-                    ax.plot(real,'o-',label='real')
-                    ax.plot(pred,'o-',label='pred')
-                    ax.legend()
-                    ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
-                    self.logger.experiment.track(Image(fig), name='cm_training_end')
-                    #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
-                    self.validation_epoch_metrics.append(self.compute_loss(batch,y_hat)+score)
-        return
-
+    def on_validation_start(self):
+        # reset buffer each epoch
+        self._val_outputs = []
+
 
     def on_validation_epoch_end(self):
         """
         pythotrch lightening stuff
 
         :meta private:
-        """
+        """
 
-        if len(self.validation_epoch_metrics)==0:
-            avg = 10000
-            beauty_string(f'THIS IS A BUG, It should be polulated','info',self.verbose)
-        else:
-            avg = torch.stack(self.validation_epoch_metrics).mean()
-        self.validation_epoch_metrics = []
+        if len(self._val_outputs)>0:
+            ys = torch.cat([o["y"] for o in self._val_outputs])
+            y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
+            if self.use_quantiles:
+                idx = 1
+            else:
+                idx = 0
+            for i in range(ys.shape[2]):
+                real = ys[0,:,i].cpu().detach().numpy()
+                pred = y_hats[0,:,i,idx].cpu().detach().numpy()
+                fig, ax = plt.subplots(figsize=(7,5))
+                ax.plot(real,'o-',label='real')
+                ax.plot(pred,'o-',label='pred')
+                ax.legend()
+                ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                self.logger.experiment.track(Image(fig), name='cm_training_end')
+                #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+                plt.close(fig)
+
+
+        avg = self.validation_epoch_metrics/self.validation_epoch_count
+
+        self.validation_epoch_metrics.zero_()
+        self.validation_epoch_count.zero_()
         self.log("val_loss", avg,sync_dist=True)
         beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
 
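Both base classes now share the same collect-in-validation_step, consume-at-epoch-end idiom, with on_validation_start giving each epoch a fresh buffer (Lightning invokes it before the first validation batch). The skeleton of that pattern on a toy module (names and shapes invented):

    import torch
    import torch.nn as nn
    try:
        import lightning.pytorch as pl
    except ImportError:
        import pytorch_lightning as pl

    class Skeleton(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(4, 1)
        def forward(self, x):
            return self.net(x)
        def on_validation_start(self):
            self._val_outputs = []                  # fresh buffer every validation epoch
        def validation_step(self, batch, batch_idx):
            self._val_outputs.append(self(batch).detach().cpu())
        def on_validation_epoch_end(self):
            preds = torch.cat(self._val_outputs)    # everything collected this epoch
            self.log("val_pred_mean", preds.mean())
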
@@ -361,14 +379,12 @@
 
         :meta private:
         """
-        if len(self.train_epoch_metrics)==0:
-            avg = 0
-            beauty_string(f'THIS IS A BUG, It should be polulated','info',self.verbose)
-        else:
-            avg = np.stack(self.train_epoch_metrics).mean()
+
+        avg = self.train_epoch_metrics/self.train_epoch_count
         self.log("train_loss", avg,sync_dist=True)
         self.count_epoch+=1
-        self.train_epoch_metrics = []
+        self.train_epoch_metrics.zero_()
+        self.train_epoch_count.zero_()
 
         self.train_loss_epoch = avg
 
dsipts-1.1.10.dist-info/METADATA → dsipts-1.1.11.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dsipts
-Version: 1.1.10
+Version: 1.1.11
 Summary: Unified library for timeseries modelling
 Author-email: Andrea Gobbi <agobbi@fbk.eu>
 Project-URL: Homepage, https://github.com/DSIP-FBK/DSIPTS
dsipts-1.1.10.dist-info/RECORD → dsipts-1.1.11.dist-info/RECORD
@@ -3,7 +3,7 @@ dsipts/data_management/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
 dsipts/data_management/monash.py,sha256=aZxq9FbIH6IsU8Lwou1hAokXjgOAK-wdl2VAeFg2k4M,13075
 dsipts/data_management/public_datasets.py,sha256=yXFzOZZ-X0ZG1DoqVU-zFmEGVMc2033YDQhRgYxY8ws,6793
 dsipts/data_structure/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dsipts/data_structure/data_structure.py,sha256=87VtKelx2EPoddrVYcja9dO5rQqaS83vZlQB_NY54PI,58994
+dsipts/data_structure/data_structure.py,sha256=uyGkc1eDjETpXb8rgMMbRUjG8i9Xiiu6vZc64xfTiew,59914
 dsipts/data_structure/modifiers.py,sha256=qlry9dfw8pEE0GrvgwROZJkJ6oPpUnjEHPIG5qIetss,7948
 dsipts/data_structure/utils.py,sha256=QwfKPZgSy6DIw5n6ztOdPJIAnzo4EnlMTgRbpiWnyko,6593
 dsipts/models/Autoformer.py,sha256=ddGT3L9T4gAXNJHx1TsuYZy7j63Anyr0rkqqXaOoSu4,8447
@@ -13,7 +13,7 @@ dsipts/models/Diffusion.py,sha256=pUujnrdeSSkj4jC1RORbcptt03KpuCsGVwg414o4LPg,40
 dsipts/models/DilatedConv.py,sha256=_c0NvFuT3vbYmo9A8cQchGo1XVb0qOpzBprNEkkAgiE,14292
 dsipts/models/DilatedConvED.py,sha256=fXk1-EWiRC5J_VIepTjYKya_D02SlEAkyiJcCjhW_XU,14004
 dsipts/models/Duet.py,sha256=EharWHT_r7tEYIk7BkozVLPZ0xptE5mmQmeFGm3uBsA,7628
-dsipts/models/ITransformer.py,sha256=jO8wxLaC06Wgu4GncrFFTISv3pVyfFLLhQvbEOYsz6Y,7368
+dsipts/models/ITransformer.py,sha256=qMsk27PqpnakNY1YM_rbkj8MO6BaG06N3b6m30Oa0RQ,7256
 dsipts/models/Informer.py,sha256=ByJ00qGk12ONFF7NZWAACzxxRb5UXcu5wpkGMYX9Cq4,6920
 dsipts/models/LinearTS.py,sha256=B0-Sz4POwUyl-PN2ssSx8L-ZHgwrQQPcMmreyvSS47U,9104
 dsipts/models/PatchTST.py,sha256=Z7DM1Kw5Ym8Hh9ywj0j9RuFtKaz_yVZmKFIYafjceM8,9061
@@ -23,13 +23,13 @@ dsipts/models/Samformer.py,sha256=s61Hi1o9iuw-KgSBPfiE80oJcK1j2fUA6N9f5BJgKJc,55
 dsipts/models/Simple.py,sha256=K82E88A62NhV_7U9Euu2cn3Q8P287HDR7eIy7VqgwbM,3909
 dsipts/models/TFT.py,sha256=JO2-AKIUag7bfm9Oeo4KmGfdYZJbzQBHPDqGVg0WUZI,13830
 dsipts/models/TIDE.py,sha256=i8qXac2gImEVgE2X6cNxqW5kuQP3rzWMlQNdgJbNmKM,13033
-dsipts/models/TTM.py,sha256=WpCiTN0qX3JFO6xgPLedoqMKXUC2pQpNAe9ee-Rw89Q,10602
+dsipts/models/TTM.py,sha256=gc-8yzEtn8ZdRVvsZfZvz7iE-RgqpZc-JGmOCQr4U_0,5215
 dsipts/models/TimeXER.py,sha256=aCg0003LxYZzqZWyWugpbW_iOybcdHN4OH6_v77qp4o,7056
 dsipts/models/VQVAEA.py,sha256=sNJi8UZh-10qEIKcZK3SzhlOFUUjvqjoglzeZBFaeZM,13789
 dsipts/models/VVA.py,sha256=BnPkJ0Nzue0oShSHZVRNlf5RvT0Iwtf9bx19vLB9Nn0,11939
 dsipts/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dsipts/models/base.py,sha256=0r_gGD9CPAVmuqTmySugTpCVUgoHJrwaMAqLx3P-ZBw,19021
-dsipts/models/base_v2.py,sha256=b_RaVTBnA2dU4HpVPI-P0_VkmbsQHtYzxVf5iFVvp1U,19299
+dsipts/models/base.py,sha256=Gqsycy8ZXGaIVx9vvmYRpBCqdUxGE4tvC5ltgxlpEYY,19640
+dsipts/models/base_v2.py,sha256=eraXo1IBEQmyW41f1dz3Q-i-61vZ2AS3tVz6_X8J0Pg,19886
 dsipts/models/utils.py,sha256=kjTwyktNCFMpPUy6zoleBCSKlvMvK_Jkgyh2T1OXg3E,24497
 dsipts/models/autoformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dsipts/models/autoformer/layers.py,sha256=xHt8V1lKdD1cIvgxXdDbI_EqOz4zgOQ6LP8l7M1pAxM,13276
@@ -76,7 +76,7 @@ dsipts/models/vva/minigpt.py,sha256=bg0JddqSD322uxSGexen3nPXL_hGTsk3vNLR62d7-w8,
 dsipts/models/vva/vqvae.py,sha256=RzCQ_M9xBprp7_x20dSV3EQqlO0FjPUGWV-qdyKrQsM,19680
 dsipts/models/xlstm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dsipts/models/xlstm/xLSTM.py,sha256=ZKZZmffmIq1Vb71CR4GSyM8viqVx-u0FChxhcNgHub8,10081
-dsipts-1.1.10.dist-info/METADATA,sha256=hwFJB926XiPZjhisLz-Usqpic_ty16lk3ZwvHoZHC0c,24795
-dsipts-1.1.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dsipts-1.1.10.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
-dsipts-1.1.10.dist-info/RECORD,,
+dsipts-1.1.11.dist-info/METADATA,sha256=fbMTKqi7b_vlvtmVSp5XJdkFrEC9SFF3DG_fKy58k_8,24795
+dsipts-1.1.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dsipts-1.1.11.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
+dsipts-1.1.11.dist-info/RECORD,,