dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1167 @@
1
+ import numpy as np
2
+ import plotly.express as px
3
+ import pandas as pd
4
+ from typing import List
5
+ from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
6
+ from sklearn.preprocessing import *
7
+ from torch.utils.data import DataLoader
8
+ from .utils import extend_time_df,MetricsCallback, MyDataset, ActionEnum,beauty_string
9
+
10
+ try:
11
+
12
+ #new version of lightning
13
+ from lightning.pytorch.callbacks import ModelCheckpoint
14
+ import lightning.pytorch as pl
15
+ from ..models.base_v2 import Base
16
+ beauty_string('V2','block',True)
17
+ OLD_PL = False
18
+ except:
19
+ ## older version of lightning
20
+ from pytorch_lightning.callbacks import ModelCheckpoint
21
+ import pytorch_lightning as pl
22
+ from ..models.base import Base
23
+ beauty_string('V1','block',True)
24
+
25
+ OLD_PL = True
26
+
27
+ from typing import Union
28
+ import os
29
+ import torch
30
+ import pickle
31
+ from datetime import datetime
32
+ from ..models.utils import weight_init_zeros,weight_init
33
+ import logging
34
+ from .modifiers import *
35
+ from aim.pytorch_lightning import AimLogger
36
+ import time
37
+
38
+
39
+
40
+ pd.options.mode.chained_assignment = None
41
+ log = logging.getLogger(__name__)
42
+ log.addHandler(logging.NullHandler())
43
+
44
+
45
+ class Categorical():
46
+
47
+ def __init__(self,name:str, frequency: int,duration: List[int], classes: int, action: ActionEnum, level: List[float]):
48
+ """Class for generating toy categorical data
49
+
50
+ Args:
51
+ name (str): name of the categorical signal
52
+ frequency (int): frequency of the signal
53
+ duration (List[int]): duration of each class
54
+ classes (int): number of classes
55
+ action (str): one between additive or multiplicative
56
+ level (List[float]): intensity of each class
57
+
58
+ """
59
+
60
+ self.name = name
61
+ self.frequency = frequency
62
+ self.duration = duration
63
+ self.classes = classes
64
+ self.action = action
65
+ self.level = level
66
+ self.validate()
67
+
68
+ def validate(self):
69
+ """Validate, maybe there will be other checks in the future
70
+
71
+ :meta private:
72
+ """
73
+ if len(self.level) == self.classes:
74
+ pass
75
+ else:
76
+                raise ValueError("Length of level must match the number of classes")
77
+
78
+ def generate_signal(self,length:int)->None:
79
+        """Generate the response signal
80
+
81
+ Args:
82
+ length (int): length of the signal
83
+ """
84
+ if self.action == 'multiplicative':
85
+ signal = np.ones(length)
86
+ elif self.action == 'additive':
87
+ signal = np.zeros(length)
88
+ classes = []
89
+ _class = 0
90
+ _level = self.level[0]
91
+ _duration = self.duration[0]
92
+
93
+ count = 0
94
+ count_freq = 0
95
+ for i in range(length):
96
+ if count_freq%self.frequency == 0:
97
+ signal[i] = _level
98
+ classes.append(_class)
99
+ count+=1
100
+ if count == _duration:
101
+ #change class
102
+ count = 0
103
+ _class+=1
104
+ count_freq+= _duration
105
+ _class = _class%self.classes
106
+ _level = self.level[_class]
107
+ _duration = self.duration[_class]
108
+
109
+
110
+ else:
111
+ classes.append(-1)
112
+ count_freq+=1
113
+
114
+ self.classes_array = classes
115
+ self.signal_array = signal
116
+
117
+ def plot(self)->None:
118
+ """Plot the series
119
+ """
120
+        tmp = pd.DataFrame({'time':range(len(self.classes_array)),'signal':self.signal_array,'class':self.classes_array})
121
+ fig = px.scatter(tmp,x='time',y='signal',color='class',title=self.name)
122
+ fig.show()
123
+
124
+
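# --- Editor's note: illustrative usage sketch for the Categorical helper above; not part of
# --- the packaged file. The name, levels and length below are made-up example values.
weekly = Categorical('weekly', 1, [1, 1, 1, 1, 1, 1, 1], 7, 'multiplicative',
                     [0.9, 0.8, 0.7, 0.6, 0.5, 0.99, 0.99])
weekly.generate_signal(length=365)   # fills weekly.signal_array and weekly.classes_array
weekly.plot()                        # scatter of the generated signal, colored by class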
125
+ class TimeSeries():
126
+
127
+ def __init__(self,name:str,stacked:bool=False):
128
+ """Class for generating time series object. If you don't have any time series you can build one fake timeseries using some helping classes (Categorical for instance).
129
+
130
+
131
+ Args:
132
+ name (str): name of the series
133
+ stacked (bool): if true it is a stacked model
134
+
135
+ Usage:
136
+ For example we can generate a toy timeseries:\n
137
+ - add a multiplicative categorical feature (weekly)\n
138
+ >>> settimana = Categorical('settimanale',1,[1,1,1,1,1,1,1],7,'multiplicative',[0.9,0.8,0.7,0.6,0.5,0.99,0.99])\n
139
+            - an additive monthly feature (here a year is composed of 5 months)\n
140
+ >>> mese = Categorical('mensile',1,[31,28,20,10,33],5,'additive',[10,20,-10,20,0])\n
141
+ - a spotted categorical variable that happens every 100 days and lasts 1 day\n
142
+ >>> spot = Categorical('spot',100,[7],1,'additive',[10])\n
143
+ >>> ts = TimeSeries('prova')\n
144
+ >>> ts.generate_signal(length = 5000,categorical_variables = [settimana,mese,spot],noise_mean=1,type=0) ##we can add also noise\n
145
+ >>> ts.plot()\n
146
+ """
147
+ self.is_trained = False
148
+ self.name = name
149
+ self.stacked = stacked
150
+ self.verbose = True
151
+ self.group = None
152
+ def __str__(self) -> str:
153
+ return f"Timeseries named {self.name} of length {self.dataset.shape[0]}.\n Categorical variable: {self.cat_var},\n Future variables: {self.future_variables},\n Past variables: {self.past_variables},\n Target variables: {self.target_variables} \n With {'no group' if self.group is None else self.group+' as group' }"
154
+ def __repr__(self) -> str:
155
+ return f"Timeseries named {self.name} of length {self.dataset.shape[0]}.\n Categorical variable: {self.cat_var},\n Future variables: {self.future_variables},\n Past variables: {self.past_variables},\n Target variables: {self.target_variables}\n With {'no group' if self.group is None else self.group+' as group' }"
156
+
157
+ def set_verbose(self,verbose:bool):
158
+ self.verbose = verbose
159
+ def _generate_base(self,length:int,type:int=0)-> None:
160
+ """Generate a basic timeseries
161
+
162
+ Args:
163
+ length (int): length
164
+ type (int, optional): Type of the generated timeseries. Defaults to 0.
165
+ """
166
+ if type==0:
167
+ self.base_signal = 10*np.cos(np.arange(length)/(2*np.pi*length/100))
168
+ self.out_vars = 1
169
+ else:
170
+ beauty_string('Please implement your own method','block',True)
171
+ """
172
+
173
+ """
174
+ def generate_signal(self,length:int=5000,categorical_variables:List[Categorical]=[],noise_mean:int=1,type:int=0)->None:
175
+        """This will generate a synthetic signal with a selected length, a noise level and some categorical variables. The additive series are added at the end while the multiplicative series act on the original signal
176
+ The TS structure will be populated
177
+
178
+ Args:
179
+ length (int, optional): length of the signal. Defaults to 5000.
180
+ categorical_variables (List[Categorical], optional): list of Categorical variables. Defaults to [].
181
+            noise_mean (int, optional): scale of the noise added at the end. Defaults to 1.
182
+ type (int, optional): type of the timeseries (only type=0 available right now). Defaults to 0.
183
+ """
184
+
185
+
186
+ dataset = pd.DataFrame({'time':range(length)})
187
+ self._generate_base(length,type)
188
+ signal = self.base_signal.copy()
189
+ tot = None
190
+ self.cat_var = []
191
+ for c in categorical_variables:
192
+ c.generate_signal(length)
193
+ _signal = c.signal_array
194
+ classes = c.classes_array
195
+ dataset[c.name] = classes
196
+ self.cat_var.append(c.name)
197
+ if c.action=='multiplicative':
198
+ signal*=_signal
199
+ else:
200
+ if tot is None:
201
+ additive = _signal
202
+ else:
203
+ additive+=_signal
204
+ signal+=additive
205
+ dataset['signal'] = signal + noise_mean*np.random.randn(len(signal))
206
+ self.dataset = dataset
207
+
208
+
209
+ self.past_variables = ['signal']
210
+ self.future_variables = []
211
+ self.target_variables = ['signal']
212
+ self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
213
+
214
+
215
+ def enrich(self,dataset,columns):
216
+ if columns =='hour':
217
+ dataset[columns] = dataset.time.dt.hour
218
+ elif columns=='dow':
219
+ dataset[columns] = dataset.time.dt.weekday
220
+ elif columns=='month':
221
+ dataset[columns] = dataset.time.dt.month
222
+ elif columns=='minute':
223
+ dataset[columns] = dataset.time.dt.minute
224
+ else:
225
+ if columns not in dataset.columns:
226
+ beauty_string(f'I can not automatically enrich column {columns}. Please contact the developers or add it manually to your dataset.','section',True)
227
+
228
+ def load_signal(self,data:pd.DataFrame,
229
+ enrich_cat:List[str] = [],
230
+ past_variables:List[str]=[],
231
+ future_variables:List[str]=[],
232
+ target_variables:List[str]=[],
233
+ cat_past_var:List[str]=[],
234
+ cat_fut_var:List[str]=[],
235
+ check_past:bool=True,
236
+ group:Union[None,str]=None,
237
+ check_holes_and_duplicates:bool=True,
238
+ silly_model:bool=False)->None:
239
+ """ This is a crucial point in the data structure. We expect here to have a dataset with time as timestamp.
240
+ There are some checks:
241
+            1- the duplicates will be removed taking the first instance
242
+
243
+            2- the frequency will be inferred taking the minimum time distance between samples
244
+
245
+ 3- the dataset will be filled completing the missing timestamps
246
+
247
+ Args:
248
+ data (pd.DataFrame): input dataset the column indicating the time must be called `time`
249
+ enrich_cat (List[str], optional): it is possible to let this function enrich the dataset for example adding the standard columns: hour, dow, month and minute. Defaults to [].
250
+ past_variables (List[str], optional): list of column names of past variables not available for future times . Defaults to [].
251
+            future_variables (List[str], optional): list of future variables available for future times. Defaults to [].
252
+            target_variables (List[str], optional): list of the target variables. They will be added to past_variables by default unless `check_past` is false. Defaults to [].
253
+            cat_past_var (List[str], optional): list of the past categorical variables. Defaults to [].
254
+            cat_fut_var (List[str], optional): list of the future categorical variables. Defaults to [].
255
+ check_past (bool, optional): see `target_variables`. Defaults to True.
256
+            group (str or None, optional): if not None the time series dataset is considered to be composed of homogeneous time series coming from different realizations (for example points of sale, cities, locations) and the relative series are not split during the sample generation. Defaults to None
257
+            check_holes_and_duplicates (bool, optional): if False duplicates and holes will not be checked and the dataloader may not work correctly; disable at your own risk. Defaults to True
258
+            silly_model (bool, optional): if True, target variables will be added to the pool of the future variables. This can be useful to see if information passes through the decoder part of your model (if any)
259
+ """
260
+
261
+
262
+
263
+ dataset = data.copy()
264
+ dataset.sort_values(by='time',inplace=True)
265
+
266
+ if check_holes_and_duplicates:
267
+ beauty_string('I will drop duplicates, I dont like them','section',self.verbose)
268
+ dataset.drop_duplicates(subset=['time'] if group is None else [group,'time'], keep='first', inplace=True, ignore_index=True)
269
+
270
+ if group is None:
271
+ differences = dataset.time.diff()[1:]
272
+ else:
273
+ differences = dataset[dataset[group]==dataset[group].unique()[0]].time.diff()[1:]
274
+
275
+
276
+ if isinstance(dataset.time[0], datetime):
277
+ freq = pd.to_timedelta(differences.min())
278
+ else:
279
+                if int(dataset.time[0])==dataset.time[0]: ##ONLY THING THAT WORKS IN GENERAL
280
+ freq = int(differences.min())
281
+ else:
282
+ raise TypeError("time must be integer or datetime")
283
+ self.freq = freq
284
+
285
+
286
+ if differences.nunique()>1:
287
+            beauty_string("There are holes in the dataset, I will try to extend the dataframe inserting NaN",'info',self.verbose)
288
+            beauty_string(f'Detected minimum frequency: {freq}','section',self.verbose)
289
+ dataset = extend_time_df(dataset,freq,group).merge(dataset,how='left')
290
+ else:
291
+ beauty_string("I will compute the frequency as minimum of the time difference",'info',self.verbose)
292
+ self.freq = dataset.time.diff()[1:].min()
293
+            if pd.api.types.is_datetime64_any_dtype(dataset.time):
294
+ self.freq = pd.to_timedelta(self.freq)
295
+
296
+
297
+ assert len(target_variables)>0, 'Provide at least one column for target'
298
+ assert 'time' in dataset.columns, 'The temporal column must be called time'
299
+ if set(target_variables).intersection(set(past_variables))!= set(target_variables):
300
+ if check_past:
301
+                beauty_string('I will update past columns adding all target columns; if you want to avoid this behaviour please set check_past to false','info',self.verbose)
302
+ past_variables = list(set(past_variables).union(set(target_variables)))
303
+
304
+ self.cat_past_var = cat_past_var
305
+ self.cat_fut_var = cat_fut_var
306
+
307
+ self.group = group
308
+ if group is not None:
309
+ if group not in cat_past_var:
310
+ beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
311
+                self.cat_past_var.append(group)
312
+ if group not in cat_fut_var:
313
+ beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
314
+ self.cat_fut_var.append(group)
315
+
316
+ self.enrich_cat = enrich_cat
317
+ for c in enrich_cat:
318
+ self.cat_past_var = list(set(self.cat_past_var+[c]))
319
+ self.cat_fut_var = list(set(self.cat_fut_var+[c]))
320
+ if c in dataset.columns:
321
+                beauty_string(f'Categorical {c} already present, it will be added to the categorical variables but the enriching function will not be called','info',self.verbose)
322
+ else:
323
+ self.enrich(dataset,c)
324
+ self.cat_var = list(set(self.cat_past_var+self.cat_fut_var)) ## all categorical data
325
+
326
+ self.dataset = dataset
327
+ self.past_variables = past_variables
328
+ self.future_variables = future_variables
329
+ self.target_variables = target_variables
330
+ self.out_vars = len(target_variables)
331
+ self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
332
+ if silly_model:
333
+ beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
334
+ self.future_variables+=self.target_variables
335
+
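# --- Editor's note: hedged usage sketch for load_signal; not part of the packaged file.
# --- 'df' is assumed to be a pandas DataFrame with a 'time' column plus the (hypothetical)
# --- columns named below.
ts = TimeSeries('example')
ts.load_signal(df,
               enrich_cat=['hour', 'dow'],       # calendar categoricals added automatically
               past_variables=['temperature'],   # known only up to the forecast time
               future_variables=[],              # known also for future timestamps
               target_variables=['y'],           # what the model will predict
               cat_past_var=['holiday'],
               cat_fut_var=['holiday'])
ts.plot()  # quick visual check of the loaded target variables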
336
+ def plot(self):
337
+ """
338
+ Easy way to control the loaded data
339
+ Returns:
340
+ plotly.graph_objects._figure.Figure: figure of the target variables
341
+ """
342
+
343
+ beauty_string('Plotting only target variables','block',self.verbose)
344
+ if self.group is None:
345
+ tmp = self.dataset[['time']+self.target_variables].melt(id_vars=['time'])
346
+ fig = px.line(tmp,x='time',y='value',color='variable',title=self.name)
347
+ fig.show()
348
+ else:
349
+ tmp = self.dataset[['time',self.group]+self.target_variables].melt(id_vars=['time',self.group])
350
+ fig = px.line(tmp,x='time',y='value',color='variable',title=self.name,facet_row=self.group)
351
+ fig.show()
352
+ return fig
353
+
354
+
355
+ def create_data_loader(self,data:pd.DataFrame,
356
+ past_steps:int,
357
+ future_steps:int,
358
+ shift:int=0,
359
+ keep_entire_seq_while_shifting:bool=False,
360
+ starting_point:Union[None,dict]=None,
361
+ skip_step:int=1,
362
+ is_inference:bool=False
363
+
364
+ )->MyDataset:
365
+ """ Create the dataset for the training/inference step
366
+
367
+ Args:
368
+ data (pd.DataFrame): input dataset, usually a subset of self.data
369
+ past_steps (int): past context length
370
+ future_steps (int): future lags to predict
371
+ shift (int, optional): if >0 the future input variables will be shifted (categorical and numerical). For example for attention model it is better to start with a know value of y and use it during the process. Defaults to 0.
372
+            keep_entire_seq_while_shifting (bool, optional): if the dataset is shifted, you may want the future data to be of length future_steps+shift (like Informer). Defaults to False
373
+            starting_point (Union[None,dict], optional): a dictionary indicating if a sample must be considered. It is checked for the first lag in the future (useful in the case your model has to predict only starting from hour 12). Defaults to None.
374
+            skip_step (int, optional): step between two consecutive samples. Usually there is a skip of one between two samples, but for debugging or training-time purposes you can skip some. Defaults to 1.
375
+ Returns:
376
+ MyDataset: class that extends torch.utils.data.Dataset (see utils)
377
+ keys of a batch:
378
+ y : the target variable(s)
379
+ x_num_past: the numerical past variables
380
+ x_num_future: the numerical future variables
381
+ x_cat_past: the categorical past variables
382
+ x_cat_future: the categorical future variables
383
+ idx_target: index of target features in the past array
384
+ """
385
+ beauty_string('Creating data loader','block',self.verbose)
386
+
387
+ x_num_past_samples = []
388
+ x_num_future_samples = []
389
+ x_cat_past_samples = []
390
+ x_cat_future_samples = []
391
+ y_samples = []
392
+ t_samples = []
393
+ g_samples = []
394
+
395
+ if starting_point is not None:
396
+ kk = list(starting_point.keys())[0]
397
+            assert kk not in self.cat_var, beauty_string(f'CAN NOT USE FEATURE {kk} as starting point, it may have a different value due to the normalization step, please add a second column with a suitable name','info',True)
398
+
399
+ ##overwrite categorical columns
400
+ for c in self.cat_var:
401
+ self.enrich(data,c)
402
+
403
+ if self.group is None:
404
+ data['_GROUP_'] = '1'
405
+ else:
406
+ data['_GROUP_'] = data[self.group].values
407
+
408
+
409
+ if self.normalize_per_group:
410
+ tot = []
411
+ groups = data[self.group].unique()
412
+
413
+ data[self.group] = self.scaler_cat[self.group].transform(data[self.group].values.reshape(-1,1)).flatten()
414
+
415
+ for group in groups:
416
+ tmp = data[data['_GROUP_']==group].copy()
417
+
418
+ for c in self.num_var:
419
+ tmp[c] = self.scaler_num[f'{c}_{group}'].transform(tmp[c].values.reshape(-1,1)).flatten()
420
+ for c in self.cat_var:
421
+ if c!=self.group:
422
+ tmp[c] = self.scaler_cat[f'{c}_{group}'].transform(tmp[c].values.reshape(-1,1)).flatten()
423
+ tot.append(tmp)
424
+ data = pd.concat(tot,ignore_index=True)
425
+ else:
426
+ for c in self.cat_var:
427
+ data[c] = self.scaler_cat[c].transform(data[c].values.reshape(-1,1)).flatten()
428
+ for c in self.num_var:
429
+ data[c] = self.scaler_num[c].transform(data[c].values.reshape(-1,1)).flatten()
430
+
431
+ idx_target = []
432
+ for c in self.target_variables:
433
+ idx_target.append(self.past_variables.index(c))
434
+
435
+ idx_target_future = []
436
+
437
+ for c in self.target_variables:
438
+ if c in self.future_variables:
439
+ idx_target_future.append(self.future_variables.index(c))
440
+ if len(idx_target_future)==0:
441
+ idx_target_future = None
442
+
443
+
444
+ if self.stacked:
445
+ skip_stacked = future_steps*future_steps-future_steps
446
+ else:
447
+ skip_stacked = 0
448
+ for group in data['_GROUP_'].unique():
449
+ tmp = data[data['_GROUP_']==group]
450
+ groups = tmp['_GROUP_'].values
451
+ t = tmp.time.values
452
+ x_num_past = tmp[self.past_variables].values
453
+ if len(self.future_variables)>0:
454
+ x_num_future = tmp[self.future_variables].values
455
+ if len(self.cat_past_var)>0:
456
+ x_past_cat = tmp[self.cat_past_var].values
457
+ if len(self.cat_fut_var)>0:
458
+ x_fut_cat = tmp[self.cat_fut_var].values
459
+ y_target = tmp[self.target_variables].values
460
+
461
+
462
+ if starting_point is not None:
463
+ check = tmp[list(starting_point.keys())[0]].values == starting_point[list(starting_point.keys())[0]]
464
+ else:
465
+ check = [True]*len(y_target)
466
+
467
+ for i in range(past_steps,tmp.shape[0]-future_steps-skip_stacked,skip_step):
468
+ if check[i]:
469
+
470
+ if len(self.future_variables)>0:
471
+ if keep_entire_seq_while_shifting:
472
+ xx = x_num_future[i-shift+skip_stacked:i+future_steps+skip_stacked].mean()
473
+ else:
474
+ xx = x_num_future[i-shift+skip_stacked:i+future_steps-shift+skip_stacked].mean()
475
+ else:
476
+ xx = 0.0
477
+ if is_inference is False:
478
+ xx+=y_target[i+skip_stacked:i+future_steps+skip_stacked].min()
479
+
480
+ if np.isfinite(x_num_past[i-past_steps:i].min() + xx):
481
+
482
+ x_num_past_samples.append(x_num_past[i-past_steps:i])
483
+ if len(self.future_variables)>0:
484
+ if keep_entire_seq_while_shifting:
485
+ x_num_future_samples.append(x_num_future[i-shift+skip_stacked:i+future_steps+skip_stacked])
486
+ else:
487
+ x_num_future_samples.append(x_num_future[i-shift+skip_stacked:i+future_steps-shift+skip_stacked])
488
+ if len(self.cat_past_var)>0:
489
+ x_cat_past_samples.append(x_past_cat[i-past_steps:i])
490
+ if len(self.cat_fut_var)>0:
491
+ if keep_entire_seq_while_shifting:
492
+ x_cat_future_samples.append(x_fut_cat[i-shift+skip_stacked:i+future_steps+skip_stacked])
493
+ else:
494
+ x_cat_future_samples.append(x_fut_cat[i-shift+skip_stacked:i+future_steps-shift+skip_stacked])
495
+
496
+ y_samples.append(y_target[i+skip_stacked:i+future_steps+skip_stacked])
497
+ t_samples.append(t[i+skip_stacked:i+future_steps+skip_stacked])
498
+ g_samples.append(groups[i])
499
+
500
+
501
+
502
+ if len(self.future_variables)>0:
503
+ try:
504
+ x_num_future_samples = np.stack(x_num_future_samples)
505
+ except Exception as e:
506
+ beauty_string('WARNING x_num_future_samples is empty and it should not','info',True)
507
+
508
+ y_samples = np.stack(y_samples)
509
+ t_samples = np.stack(t_samples)
510
+ g_samples = np.stack(g_samples)
511
+
512
+ if len(self.cat_past_var)>0:
513
+ x_cat_past_samples = np.stack(x_cat_past_samples).astype(np.int32)
514
+ if len(self.cat_fut_var)>0:
515
+ x_cat_future_samples = np.stack(x_cat_future_samples).astype(np.int32)
516
+ x_num_past_samples = np.stack(x_num_past_samples)
517
+ if self.stacked:
518
+ mod = 0
519
+ else:
520
+ mod = 1.0
521
+ dd = {'y':y_samples.astype(np.float32),
522
+
523
+ 'x_num_past':(x_num_past_samples*mod).astype(np.float32)}
524
+ if len(self.cat_past_var)>0:
525
+ dd['x_cat_past'] = x_cat_past_samples
526
+ if len(self.cat_fut_var)>0:
527
+ dd['x_cat_future'] = x_cat_future_samples
528
+ if len(self.future_variables)>0:
529
+ dd['x_num_future'] = x_num_future_samples.astype(np.float32)
530
+
531
+ return MyDataset(dd,t_samples,g_samples,idx_target,idx_target_future)
532
+
533
+
534
+
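# --- Editor's note: hedged sketch (not part of the packaged file) showing how a dataset built
# --- by split_for_train/create_data_loader can be inspected; 'ts' is a TimeSeries with data
# --- already loaded and the batch size is arbitrary. Keys follow the docstring above.
from torch.utils.data import DataLoader

train_ds, valid_ds, test_ds = ts.split_for_train(past_steps=100, future_steps=20)
batch = next(iter(DataLoader(train_ds, batch_size=8)))
# expected keys (when the corresponding variables exist):
# 'y', 'x_num_past', 'x_num_future', 'x_cat_past', 'x_cat_future'
print({k: v.shape for k, v in batch.items()})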
535
+ def split_for_train(self,
536
+ perc_train:Union[float,None]=0.6,
537
+ perc_valid:Union[float,None]=0.2,
538
+ range_train:Union[List[Union[datetime, str]],None]=None,
539
+ range_validation:Union[List[Union[datetime, str]],None]=None,
540
+ range_test:Union[List[Union[datetime, str]],None]=None,
541
+ past_steps:int = 100,
542
+ future_steps:int=20,
543
+ shift:int = 0,
544
+ keep_entire_seq_while_shifting:bool=False,
545
+ starting_point:Union[None, dict]=None,
546
+ skip_step:int=1,
547
+ normalize_per_group: bool=False,
548
+ check_consecutive: bool=True,
549
+ scaler: str='StandardScaler()'
550
+ )->List[DataLoader]:
551
+ """Split the data and create the datasets.
552
+
553
+ Args:
554
+ perc_train (Union[float,None], optional): fraction of the training set. Defaults to 0.6.
555
+ perc_valid (Union[float,None], optional): fraction of the test set. Defaults to 0.2.
556
+ range_train (Union[List[Union[datetime, str]],None], optional): a list of two elements indicating the starting point and end point of the training set (string date style or datetime). Defaults to None.
557
+ range_validation (Union[List[Union[datetime, str]],None], optional):a list of two elements indicating the starting point and end point of the validation set (string date style or datetime). Defaults to None.
558
+ range_test (Union[List[Union[datetime, str]],None], optional): a list of two elements indicating the starting point and end point of the test set (string date style or datetime). Defaults to None.
559
+ past_steps (int, optional): past step to consider for making the prediction. Defaults to 100.
560
+ future_steps (int, optional): future step to predict. Defaults to 20.
561
+ shift (int, optional): see `create_data_loader`. Defaults to 0.
562
+            keep_entire_seq_while_shifting (bool, optional): if the dataset is shifted, you may want the future data to be of length future_steps+shift (like Informer). Defaults to False
563
+
564
+ starting_point (Union[None, dict], optional): see `create_data_loader`. Defaults to None.
565
+ skip_step (int, optional): see `create_data_loader`. Defaults to 1.
566
+            normalize_per_group (boolean, optional): if true and self.group is not None, the variables are scaled with respect to the groups. Defaults to False
567
+ check_consecutive (boolean, optional): if false it skips the check on the consecutive ranges. Default True
568
+ scaler: instance of a sklearn.preprocessing scaler. Default 'StandardScaler()'
569
+ Returns:
570
+            List[DataLoader]: the three dataloaders (train, validation, test) used for training or inference
571
+ """
572
+
573
+ beauty_string('Splitting for train','block',self.verbose)
574
+
575
+
576
+ try:
577
+ ls = self.dataset.shape[0]
578
+ except Exception as _:
579
+ beauty_string('Empty dataset','info', True)
580
+ return None, None, None
581
+
582
+ if range_train is None:
583
+ if self.group is None:
584
+ beauty_string(f'Split temporally using perc_train: {perc_train} and perc_valid:{perc_valid}','section',self.verbose)
585
+ train = self.dataset.iloc[0:int(perc_train*ls)]
586
+ validation = self.dataset.iloc[int(perc_train*ls):int(perc_train*ls+perc_valid*ls)]
587
+ test = self.dataset.iloc[int(perc_train*ls+perc_valid*ls):]
588
+ else:
589
+ beauty_string(f'Split temporally using perc_train: {perc_train} and perc_valid:{perc_valid} for each group!','info',self.verbose)
590
+ train = []
591
+ validation =[]
592
+ test = []
593
+ ls = self.dataset.groupby(self.group).time.count().reset_index()
594
+ for group in self.dataset[self.group].unique():
595
+ tmp = self.dataset[self.dataset[self.group]==group]
596
+ lt = ls[ls[self.group]==group].time.values[0]
597
+ train.append(tmp[0:int(perc_train*lt)])
598
+ validation.append(tmp[int(perc_train*lt):int(perc_train*lt+perc_valid*lt)])
599
+ test.append(tmp[int(perc_train*lt+perc_valid*lt):])
600
+
601
+ train = pd.concat(train,ignore_index=True)
602
+ validation = pd.concat(validation,ignore_index=True)
603
+ test = pd.concat(test,ignore_index=True)
604
+
605
+
606
+ else:
607
+ if check_consecutive:
608
+                assert range_train[0]<range_train[1]<=range_validation[0]<range_validation[1]<=range_test[0]<range_test[1], beauty_string('The ranges are not correct','info',True)
609
+ beauty_string('Split temporally using the time intervals provided','section',self.verbose)
610
+ train = self.dataset[self.dataset.time.between(range_train[0],range_train[1])]
611
+ validation = self.dataset[self.dataset.time.between(range_validation[0],range_validation[1])]
612
+ test = self.dataset[self.dataset.time.between(range_test[0],range_test[1])]
613
+
614
+
615
+ beauty_string('Train categorical and numerical scalers','block',self.verbose)
616
+
617
+ if self.is_trained:
618
+ pass
619
+ else:
620
+ self.scaler_cat = {}
621
+ self.scaler_num = {}
622
+ if self.group is None or normalize_per_group is False:
623
+ self.normalize_per_group = False
624
+ for c in self.num_var:
625
+ self.scaler_num[c] = eval(scaler)
626
+ self.scaler_num[c].fit(train[c].values.reshape(-1,1))
627
+ for c in self.cat_var:
628
+ self.scaler_cat[c] = OrdinalEncoder(dtype=np.int32,handle_unknown= 'use_encoded_value',unknown_value=train[c].nunique())
629
+
630
+ self.scaler_cat[c].fit(train[c].values.reshape(-1,1))
631
+ else:
632
+ self.normalize_per_group = True
633
+                self.scaler_cat[self.group] = OrdinalEncoder(dtype=np.int32,handle_unknown= 'use_encoded_value',unknown_value=train[self.group].nunique())
634
+ self.scaler_cat[self.group].fit(train[self.group].values.reshape(-1,1))
635
+ for group in train[self.group].unique():
636
+ tmp = train[train[self.group]==group]
637
+
638
+ for c in self.num_var:
639
+ self.scaler_num[f'{c}_{group}'] = eval(scaler)
640
+ self.scaler_num[f'{c}_{group}'].fit(tmp[c].values.reshape(-1,1))
641
+ for c in self.cat_var:
642
+ if c!=self.group:
643
+ self.scaler_cat[f'{c}_{group}'] = OrdinalEncoder(dtype=np.int32,handle_unknown= 'use_encoded_value',unknown_value=train[c].nunique())
644
+ self.scaler_cat[f'{c}_{group}'].fit(tmp[c].values.reshape(-1,1))
645
+
646
+ dl_train = self.create_data_loader(train,past_steps,future_steps,shift,keep_entire_seq_while_shifting,starting_point,skip_step)
647
+ dl_validation = self.create_data_loader(validation,past_steps,future_steps,shift,keep_entire_seq_while_shifting,starting_point,skip_step)
648
+ if test.shape[0]>0:
649
+ dl_test = self.create_data_loader(test,past_steps,future_steps,shift,keep_entire_seq_while_shifting,starting_point,skip_step)
650
+ else:
651
+ dl_test = None
652
+ return dl_train,dl_validation,dl_test
653
+
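# --- Editor's note: hedged sketch of split_for_train with explicit date ranges; not part of
# --- the packaged file. Dates and window lengths are arbitrary examples; the range_* arguments
# --- accept date strings or datetime objects.
train_ds, valid_ds, test_ds = ts.split_for_train(range_train=['2020-01-01', '2021-12-31'],
                                                 range_validation=['2022-01-01', '2022-06-30'],
                                                 range_test=['2022-07-01', '2022-12-31'],
                                                 past_steps=100,
                                                 future_steps=20,
                                                 scaler='StandardScaler()')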
654
+ def set_model(self,model:Base,config:dict=None,custom_init:bool=False):
655
+ """Set the model to train
656
+
657
+ Args:
658
+ model (Base): see `models`
659
+ config (dict, optional): usually the configuration used by the model. Defaults to None.
660
+ custom_init (bool, optional): if true a custom initialization paradigm will be used (see weight_init in models/utils.py ) .
661
+ """
662
+ self.model = model
663
+ if custom_init:
664
+ self.model.apply(weight_init)
665
+ #self.model.apply(weight_init_zeros)
666
+
667
+ self.config = config
668
+
669
+ beauty_string('Setting the model','block',self.verbose)
670
+ beauty_string(model,'',self.verbose)
671
+
672
+ def train_model(self,dirpath:str,
673
+ split_params:dict,
674
+ batch_size:int=100,
675
+ num_workers:int=4,
676
+ max_epochs:int=500,
677
+ auto_lr_find:bool=True,
678
+ gradient_clip_val:Union[float,None]=None,
679
+ gradient_clip_algorithm:str="value",
680
+ devices:Union[str,List[int]]='auto',
681
+ precision:Union[str,int]=32,
682
+ modifier:Union[None,str]=None,
683
+ modifier_params:Union[None,dict]=None,
684
+ seed:int=42
685
+ )-> float:
686
+ """Train the model
687
+
688
+ Args:
689
+ dirpath (str): path where to put all the useful things
690
+ split_params (dict): see `split_for_train`
691
+ batch_size (int, optional): batch size. Defaults to 100.
692
+ num_workers (int, optional): num_workers for the dataloader. Defaults to 4.
693
+ max_epochs (int, optional): maximum epochs to perform. Defaults to 500.
694
+            auto_lr_find (bool, optional): find initial learning rate, see `pytorch-lightning`. Defaults to True.
695
+ gradient_clip_val (Union[float,None], optional): gradient_clip_val. Defaults to None. See https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html
696
+            gradient_clip_algorithm (str, optional): gradient_clip_algorithm. Defaults to 'value'. See https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html
697
+ devices (Union[str,List[int]], optional): devices to use. Use auto if cpu or the list of gpu to use otherwise. Defaults to 'auto'.
698
+            precision (Union[str,int], optional): precision to use. Usually 32 bit is fine but for larger models you should try 'bf16'. If 'auto' it will use bf16 on GPU and 32 on CPU
699
+            modifier (Union[None,str], optional): if not None a modifier is applied to the dataloaders. Sometimes lightning has very restrictive rules on the dataloader, or we want to use an ML model before or after the DL model (see the README for more information)
700
+            modifier_params (Union[None,dict], optional): parameters of the modifier
701
+ seed (int, optional): seed for reproducibility
702
+ """
703
+
704
+ beauty_string('Training the model','block',self.verbose)
705
+
706
+ self.split_params = split_params
707
+ self.check_custom = False
708
+ train,validation,test = self.split_for_train(**self.split_params)
709
+ accelerator = 'gpu' if torch.cuda.is_available() else "cpu"
710
+ strategy = "auto"
711
+ if accelerator == 'gpu':
712
+ strategy = "auto" ##TODO in future investigate on this
713
+ if precision=='auto':
714
+ precision = 'bf16'
715
+                #"bf16" ## in the future maybe move this into the configs; bfloat16 might not work for some models
716
+ torch.set_float32_matmul_precision('medium')
717
+ beauty_string('Setting multiplication precision to medium','info',self.verbose)
718
+ else:
719
+ devices = 'auto'
720
+ if precision=='auto':
721
+ precision = 32
722
+ beauty_string(f'train:{len(train)}, validation:{len(validation)}, test:{len(test) if test is not None else 0}','section',self.verbose)
723
+ if (accelerator=='gpu') and (num_workers>0):
724
+ persistent_workers = True
725
+ else:
726
+ persistent_workers = False
727
+
728
+
729
+ if modifier is not None:
730
+ modifier = eval(modifier)
731
+ modifier = modifier(**modifier_params)
732
+ train, validation = modifier.fit_transform(train=train,val=validation)
733
+ self.modifier = modifier
734
+ else:
735
+ self.modifier = None
736
+
737
+
738
+ train_dl = DataLoader(train, batch_size = batch_size , shuffle=True,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
739
+ valid_dl = DataLoader(validation, batch_size = batch_size , shuffle=False,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
740
+
741
+ checkpoint_callback = ModelCheckpoint(dirpath=dirpath,
742
+ monitor='val_loss',
743
+ save_last = True,
744
+ every_n_epochs =1,
745
+ verbose = self.verbose,
746
+ save_top_k = 1,
747
+ filename='checkpoint')
748
+
749
+
750
+ #logger = CSVLogger("logs", name=dirpath)
751
+ aim_logger = AimLogger(
752
+ experiment=self.name,
753
+ train_metric_prefix='train_',
754
+ val_metric_prefix='val_',
755
+ )
756
+
757
+ #https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
758
+ n_params = sum(dict((p.data_ptr(), p.numel()) for p in self.model.parameters()).values())
759
+
760
+ #https://discuss.pytorch.org/t/finding-model-size/130275
761
+ param_size = 0
762
+ for param in self.model.parameters():
763
+ param_size += param.nelement() * param.element_size()
764
+ buffer_size = 0
765
+ for buffer in self.model.buffers():
766
+ buffer_size += buffer.nelement() * buffer.element_size()
767
+
768
+ size_all_mb = (param_size + buffer_size) / 1024**2
769
+ #aim_logger.experiment.track(self.model.name,name='model_name')
770
+
771
+ aim_logger.experiment.track(n_params,name='N-parameters')
772
+ aim_logger.experiment.track(size_all_mb,name='dim-model-MB')
773
+ aim_logger.experiment.track(len(train_dl.dataset),name='len-train')
774
+ aim_logger.experiment.track(len(valid_dl.dataset),name='len-valid')
775
+ #aim_logger.experiment.track(self.config,name=None)
776
+ tmp = self.config.copy()
777
+ tmp['model_name'] = self.model.name
778
+ aim_logger._run['hyperparameters'] = tmp
779
+
780
+ mc = MetricsCallback(dirpath)
781
+        ## TODO with 2 or more GPUs MetricsCallback does not work (I think it creates one instance per data-parallel worker and then cannot recover the info)
782
+ pl.seed_everything(seed, workers=True)
783
+ self.model.max_epochs = max_epochs
784
+
785
+ if os.path.isfile(os.path.join(dirpath,'last.ckpt')):
786
+ weight_exists = True
787
+ beauty_string('I loaded a previous checkpoint','section',self.verbose)
788
+
789
+ else:
790
+ weight_exists = False
791
+ beauty_string('I can not load a previous model','section',self.verbose)
792
+
793
+
794
+
795
+
796
+ if OLD_PL:
797
+ trainer = pl.Trainer(default_root_dir=dirpath,
798
+ logger = aim_logger,
799
+ max_epochs=max_epochs,
800
+ callbacks=[checkpoint_callback,mc],
801
+ auto_lr_find=auto_lr_find,
802
+ accelerator=accelerator,
803
+ devices=devices,
804
+ strategy=strategy,
805
+ enable_progress_bar=False,
806
+ precision=precision,
807
+ gradient_clip_val=gradient_clip_val,
808
+ gradient_clip_algorithm=gradient_clip_algorithm)#,devices=1)
809
+ else:
810
+ trainer = pl.Trainer(default_root_dir=dirpath,
811
+ logger = aim_logger,
812
+ max_epochs=max_epochs,
813
+ callbacks=[checkpoint_callback,mc],
814
+ strategy='auto',
815
+ devices=devices,
816
+ enable_progress_bar=False,
817
+ precision=precision,
818
+ gradient_clip_val=gradient_clip_val,
819
+ gradient_clip_algorithm=gradient_clip_algorithm)#,devices=1)
820
+ tot_seconds = time.time()
821
+
822
+
823
+
824
+
825
+ if auto_lr_find and (weight_exists is False):
826
+ if OLD_PL:
827
+ lr_tuner = trainer.tune(self.model,train_dataloaders=train_dl,val_dataloaders = valid_dl)
828
+ files = os.listdir(dirpath)
829
+ for f in files:
830
+ if '.lr_find' in f:
831
+ os.remove(os.path.join(dirpath,f))
832
+ self.model.optim_config['lr'] = lr_tuner['lr_find'].suggestion()
833
+ else:
834
+ from lightning.pytorch.tuner import Tuner
835
+ tuner = Tuner(trainer)
836
+ lr_finder = tuner.lr_find(self.model,train_dataloaders=train_dl,val_dataloaders = valid_dl)
837
+ self.model.optim_config['lr'] = lr_finder.suggestion() ## we are using it as optim key
838
+
839
+
840
+
841
+ if OLD_PL:
842
+ if weight_exists:
843
+ trainer.fit(self.model, train_dl,valid_dl,ckpt_path=os.path.join(dirpath,'last.ckpt'))
844
+ else:
845
+ trainer.fit(self.model, train_dl,valid_dl)
846
+ else:
847
+ if weight_exists:
848
+ trainer.fit(self.model, train_dataloaders = train_dl,val_dataloaders = valid_dl,ckpt_path=os.path.join(dirpath,'last.ckpt'))
849
+ else:
850
+ trainer.fit(self.model, train_dataloaders = train_dl,val_dataloaders = valid_dl)
851
+ self.checkpoint_file_best = checkpoint_callback.best_model_path
852
+ self.checkpoint_file_last = checkpoint_callback.last_model_path
853
+ if self.checkpoint_file_last=='':
854
+ beauty_string('There is a bug on saving last model I will try to fix it','info',self.verbose)
855
+ self.checkpoint_file_last = checkpoint_callback.best_model_path.replace('checkpoint','last')
856
+
857
+ self.dirpath = dirpath
858
+
859
+ self.losses = mc.metrics
860
+
861
+ files = os.listdir(dirpath)
862
+        ## workaround for multi GPU
863
+ for f in files:
864
+ if '__losses__.csv' in f:
865
+ if len(self.losses['val_loss'])>0:
866
+ self.losses = pd.DataFrame(self.losses)
867
+ else:
868
+ self.losses = pd.read_csv(os.path.join(os.path.join(dirpath,f)))
869
+ os.remove(os.path.join(os.path.join(dirpath,f)))
870
+ if isinstance(self.losses,dict):
871
+ self.losses = pd.DataFrame()
872
+
873
+ try:
874
+ if OLD_PL:
875
+ self.model = self.model.load_from_checkpoint(self.checkpoint_file_last)
876
+ else:
877
+ self.model = self.model.__class__.load_from_checkpoint(self.checkpoint_file_last)
878
+
879
+ except Exception as _:
880
+ beauty_string(f'There is a problem loading the weights on file MAYBE CHANGED HOW WEIGHTS ARE LOADED {self.checkpoint_file_last}','section',self.verbose)
881
+
882
+ try:
883
+ val_loss = self.losses.val_loss.values[-1]
884
+ except Exception as _:
885
+ beauty_string('Can not extract the validation loss, maybe it is a persistent model','info',self.verbose)
886
+ val_loss = 100
887
+ self.is_trained = True
888
+
889
+ beauty_string('END of the training process','block',self.verbose)
890
+
891
+ aim_logger.experiment.track((time.time()-tot_seconds),name='seconds-training')
892
+ aim_logger.experiment.track(val_loss,name='val-loss-end-train')
893
+
894
+
895
+
896
+ return val_loss
897
+
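# --- Editor's note: hedged sketch of the training entry point; not part of the packaged file.
# --- 'my_model' stands for an already-built dsipts model instance (see dsipts/models) and
# --- 'config' for the dict used to build it; both are placeholders, as is the output folder.
split_params = dict(perc_train=0.6, perc_valid=0.2, past_steps=100, future_steps=20)
ts.set_model(my_model, config=config)
val_loss = ts.train_model(dirpath='checkpoints/example',
                          split_params=split_params,
                          batch_size=100,
                          max_epochs=50)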
898
+ def inference_on_set(self,batch_size:int=100,
899
+ num_workers:int=4,
900
+ split_params:Union[None,dict]=None,set:str='test',
901
+ rescaling:bool=True,
902
+ data:Union[None,torch.utils.data.Dataset]=None)->pd.DataFrame:
903
+ """This function allows to get the prediction on a particular set (train, test or validation).
904
+
905
+ Args:
906
+            batch_size (int, optional): batch size. Defaults to 100.
907
+            num_workers (int, optional): num workers. Defaults to 4.
908
+            split_params (Union[None,dict], optional): if not None the splitting procedure will use the given parameters, otherwise it will use the same configuration used in train. Defaults to None.
909
+            set (str, optional): train, validation or test. Defaults to 'test'.
910
+            rescaling (bool, optional): If rescaling is true the output will be rescaled to the initial values. Defaults to True.
911
+            data (None or pd.DataFrame, optional): If not None the inference is performed on the given data. In the case of custom data please call inference because it will normalize the data for you!
912
+ Returns:
913
+ pd.DataFrame: the predicted values in a pandas format
914
+ """
915
+
916
+ beauty_string('Inference on a set (train, validation o test)','block',self.verbose)
917
+
918
+ if data is None:
919
+ if split_params is None:
920
+ beauty_string(f'splitting using train parameters {self.split_params}','section',self.verbose)
921
+ train,validation,test = self.split_for_train(**self.split_params)
922
+ else:
923
+ train,validation,test = self.split_for_train(**split_params)
924
+
925
+ if set=='test':
926
+ if self.modifier is not None:
927
+ test = self.modifier.transform(test)
928
+ dl = DataLoader(test, batch_size = batch_size , shuffle=False,drop_last=False,num_workers=num_workers)
929
+ elif set=='validation':
930
+ if self.modifier is not None:
931
+ validation = self.modifier.transform(validation)
932
+ dl = DataLoader(validation, batch_size = batch_size , shuffle=False,drop_last=False,num_workers=num_workers)
933
+ elif set=='train':
934
+ if self.modifier is not None:
935
+ train = self.modifier.transform(train)
936
+ dl = DataLoader(train, batch_size = batch_size , shuffle=False,drop_last=False,num_workers=num_workers)
937
+ elif set=='custom':
938
+ if self.check_custom:
939
+ pass
940
+ else:
941
+ beauty_string('If you are here something went wrong, please report it','section',self.verbose)
942
+ if self.modifier is not None:
943
+ data = self.modifier.transform(data)
944
+ dl = DataLoader(data, batch_size = batch_size , shuffle=False,drop_last=False,num_workers=num_workers)
945
+
946
+ else:
947
+ beauty_string('Select one of train, test, or validation set','section',self.verbose)
948
+ self.model.eval()
949
+
950
+ res = []
951
+ real = []
952
+ self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
953
+ beauty_string(f'Device used: {self.model.device}','info',self.verbose)
954
+
955
+ for batch in dl:
956
+ res.append(self.model.inference(batch).cpu().detach().numpy())
957
+ real.append(batch['y'].cpu().detach().numpy())
958
+
959
+ res = np.vstack(res)
960
+
961
+ real = np.vstack(real)
962
+ time = dl.dataset.t
963
+ groups = dl.dataset.groups
964
+ #import pdb
965
+ #pdb.set_trace()
966
+ if self.modifier is not None:
967
+ res,real = self.modifier.inverse_transform(res,real)
968
+
969
+ ## BxLxCx3
970
+ if rescaling:
971
+ beauty_string('Scaling back','info',self.verbose)
972
+ if self.normalize_per_group is False:
973
+ for i, c in enumerate(self.target_variables):
974
+ real[:,:,i] = self.scaler_num[c].inverse_transform(real[:,:,i].reshape(-1,1)).reshape(-1,real.shape[1])
975
+ for j in range(res.shape[3]):
976
+ res[:,:,i,j] = self.scaler_num[c].inverse_transform(res[:,:,i,j].reshape(-1,1)).reshape(-1,res.shape[1])
977
+ else:
978
+ for group in np.unique(groups):
979
+ idx = np.where(groups==group)[0]
980
+ for i, c in enumerate(self.target_variables):
981
+ real[idx,:,i] = self.scaler_num[f'{c}_{group}'].inverse_transform(real[idx,:,i].reshape(-1,1)).reshape(-1,real.shape[1])
982
+ for j in range(res.shape[3]):
983
+ res[idx,:,i,j] = self.scaler_num[f'{c}_{group}'].inverse_transform(res[idx,:,i,j].reshape(-1,1)).reshape(-1,res.shape[1])
984
+
985
+ if self.model.use_quantiles:
986
+ time = pd.DataFrame(time,columns=[i+1 for i in range(res.shape[1])])
987
+
988
+ if self.group is not None:
989
+ time[self.group] = groups
990
+                time = time.melt(id_vars=[self.group])
991
+ else:
992
+ time = time.melt()
993
+ time.rename(columns={'value':'time','variable':'lag'},inplace=True)
994
+
995
+
996
+ tot = [time]
997
+ for i, c in enumerate(self.target_variables):
998
+ tot.append(pd.DataFrame(real[:,:,i],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c}).drop(columns=['variable']))
999
+ tot.append(pd.DataFrame(res[:,:,i,0],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c+'_low'}).drop(columns=['variable']))
1000
+ tot.append(pd.DataFrame(res[:,:,i,1],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c+'_median'}).drop(columns=['variable']))
1001
+ tot.append(pd.DataFrame(res[:,:,i,2],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c+'_high'}).drop(columns=['variable']))
1002
+
1003
+ res = pd.concat(tot,axis=1)
1004
+
1005
+
1006
+ ## BxLxCx1
1007
+ else:
1008
+ time = pd.DataFrame(time,columns=[i+1 for i in range(res.shape[1])])#.melt()
1009
+
1010
+ if self.group is not None:
1011
+ time[self.group] = groups
1012
+                time = time.melt(id_vars=[self.group])
1013
+ else:
1014
+ time = time.melt()
1015
+ time.rename(columns={'value':'time','variable':'lag'},inplace=True)
1016
+
1017
+
1018
+ tot = [time]
1019
+ for i, c in enumerate(self.target_variables):
1020
+ tot.append(pd.DataFrame(real[:,:,i],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c}).drop(columns=['variable']))
1021
+ tot.append(pd.DataFrame(res[:,:,i,0],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c+'_pred'}).drop(columns=['variable']))
1022
+ res = pd.concat(tot,axis=1)
1023
+
1024
+ res['prediction_time'] = res.apply(lambda x: x.time-self.freq*x.lag, axis=1)
1025
+ return res
1026
+ def inference(self,batch_size:int=100,
1027
+ num_workers:int=4,
1028
+ split_params:Union[None,dict]=None,
1029
+ rescaling:bool=True,
1030
+ data:pd.DataFrame=None,
1031
+ steps_in_future:int=0,
1032
+ check_holes_and_duplicates:bool=True,
1033
+ is_inference:bool=False)->pd.DataFrame: ##TODO PUSH THIS ON PTF!
1034
+
1035
+ """similar to `inference_on_set`
1036
+        the only change is split_params, which must contain these keys (using the defaults can be sufficient):
1037
+        'past_steps','future_steps','shift','keep_entire_seq_while_shifting','starting_point'
1038
+
1039
+        skip_step is set to 1 for convenience (generally you want all the predictions)
1040
+        You can set split_params to None and use the standard parameters (at your own risk)
1041
+
1042
+
1043
+ Args:
1044
+ batch_size (int, optional): see inference_on_set. Defaults to 100.
1045
+ num_workers (int, optional): inference_on_set. Defaults to 4.
1046
+ split_params (Union[None,dict], optional): inference_on_set. Defaults to None.
1047
+ rescaling (bool, optional): inference_on_set. Defaults to True.
1048
+            data (pd.DataFrame, optional): starting dataset. Defaults to None.
1049
+            steps_in_future (int, optional): if >0 the dataset is extended in order to make predictions in the future. Defaults to 0.
1050
+ check_holes_and_duplicates (bool, optional): if False the routine does not check for holes or for duplicates, set to False for stacked model. Defaults to True.
1051
+
1052
+ Returns:
1053
+ pd.DataFrame: predicted values
1054
+ """
1055
+ beauty_string('Inference on a custom dataset','block',self.verbose)
1056
+ self.check_custom = True ##this is a check for the dataset loading
1057
+ ## enlarge the dataset in order to have all the rows needed
1058
+ if check_holes_and_duplicates:
1059
+ if self.group is None:
1060
+ ##freq = pd.to_timedelta(np.diff(data.time).min())
1061
+ freq = self.freq #TODO port it into PTF
1062
+                beauty_string(f'Detected minimum frequency: {freq}','section',self.verbose)
1063
+ ## TODO work on this for consistency
1064
+ empty = pd.DataFrame({'time':pd.date_range(data.time.min(),data.time.max()+freq*(steps_in_future+self.split_params['past_steps']+self.split_params['future_steps']),freq=freq)})
1065
+
1066
+ else:
1067
+                freq = pd.to_timedelta(np.diff(data[data[self.group]==data[self.group].unique()[0]].time).min())
1068
+                beauty_string(f'Detected minimum frequency: {freq}, assuming constant frequency inside the groups','section',self.verbose)
1069
+ _min = data.groupby(self.group).time.min()
1070
+ _max = data.groupby(self.group).time.max()
1071
+ empty = []
1072
+ for c in data[self.group].unique():
1073
+ empty.append(pd.DataFrame({self.group:c,'time':pd.date_range(_min.time[_min[self.group]==c].values[0],_max.time[_max[self.group]==c].values[0]+freq*(steps_in_future+self.split_params['past_steps']+self.split_params['future_steps']),freq=freq)}))
1074
+ empty = pd.concat(empty,ignore_index=True)
1075
+ dataset = empty.merge(data,how='left')
1076
+ #TODO port it into PTF
1077
+ for c in self.cat_var:
1078
+ self.enrich(dataset, c)
1079
+ else:
1080
+ dataset = data.copy()
1081
+
1082
+
1083
+ if split_params is None:
1084
+ split_params = {}
1085
+ for c in self.split_params.keys():
1086
+ if c in ['past_steps','future_steps','shift','keep_entire_seq_while_shifting','starting_point']:
1087
+ split_params[c] = self.split_params[c]
1088
+ split_params['skip_step']=1
1089
+ data = self.create_data_loader(dataset,**split_params,is_inference=is_inference)
1090
+ else:
1091
+ data = self.create_data_loader(data,**split_params,is_inference=is_inference)
1092
+
1093
+ res = self.inference_on_set(batch_size=batch_size,num_workers=num_workers,split_params=None,set='custom',rescaling=rescaling,data=data)
1094
+ self.check_custom = False
1095
+ return res
1096
+
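# --- Editor's note: hedged sketch of the two inference helpers above; not part of the
# --- packaged file. 'new_df' is a placeholder DataFrame with the same columns used in training.
preds_test = ts.inference_on_set(batch_size=100, set='test', rescaling=True)
preds_new = ts.inference(data=new_df, steps_in_future=20)  # extends the frame into the future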
1097
+ def save(self, filename:str)->None:
1098
+ """save the timeseries object
1099
+
1100
+ Args:
1101
+ filename (str): name of the file
1102
+ """
1103
+ beauty_string('Saving','block',self.verbose)
1104
+ with open(f'{filename}.pkl','wb') as f:
1105
+ params = self.__dict__.copy()
1106
+ for k in ['model']:
1107
+ if k in params.keys():
1108
+ _ = params.pop(k)
1109
+ pickle.dump(params,f)
1110
+
1111
+
1112
+ def load(self,model:Base, filename:str,load_last:bool=True,dirpath:Union[str,None]=None,weight_path:Union[str, None]=None)->None:
1113
+ """ Load a saved model
1114
+
1115
+ Args:
1116
+            model (Base): class of the model to load (it will be instantiated by pytorch-lightning)
1117
+            filename (str): filename of the saved model
1118
+            load_last (bool, optional): if true the last checkpoint will be loaded otherwise the best (in the validation set). Defaults to True.
1119
+            dirpath (Union[str,None], optional): if None we assume that the model is loaded from the same pc where it has been trained, otherwise we can pass the dirpath where everything has been saved. Defaults to None.
1120
+ weight_path (Union[str, None], optional): if None the standard path will be used. Defaults to None.
1121
+ """
1122
+
1123
+
1124
+
1125
+ beauty_string('Loading','block',self.verbose)
1126
+ self.modifier = None
1127
+ self.check_custom = False
1128
+ self.is_trained = True
1129
+ with open(filename+'.pkl','rb') as f:
1130
+ params = pickle.load(f)
1131
+ for p in params:
1132
+ setattr(self,p, params[p])
1133
+ if 'verbose' in self.config['model_configs'].keys():
1134
+ self.config['model_configs'].pop('verbose')
1135
+ self.model = model(**self.config['model_configs'],optim_config = self.config['optim_config'],scheduler_config =self.config['scheduler_config'],verbose=self.verbose )
1136
+
1137
+
1138
+ if weight_path is not None:
1139
+ tmp_path = weight_path
1140
+ else:
1141
+ if self.dirpath is not None:
1142
+ directory = self.dirpath
1143
+ else:
1144
+ directory = dirpath
1145
+
1146
+ if load_last:
1147
+
1148
+ try:
1149
+ tmp_path = os.path.join(directory,self.checkpoint_file_last.split('/')[-1])
1150
+ except Exception as _:
1151
+                    beauty_string('checkpoint_file_last not defined, trying to load best','section',self.verbose)
1152
+ tmp_path = os.path.join(directory,self.checkpoint_file_best.split('/')[-1])
1153
+ else:
1154
+ try:
1155
+ tmp_path = os.path.join(directory,self.checkpoint_file_best.split('/')[-1])
1156
+ except Exception as _:
1157
+                    beauty_string('checkpoint_file_best not defined, trying to load last','section',self.verbose)
1158
+ tmp_path = os.path.join(directory,self.checkpoint_file_last.split('/')[-1])
1159
+ try:
1160
+ #with torch.serialization.add_safe_globals([ListConfig]):
1161
+ if OLD_PL:
1162
+ self.model = self.model.load_from_checkpoint(tmp_path,verbose=self.verbose,)
1163
+ else:
1164
+ self.model = self.model.__class__.load_from_checkpoint(tmp_path,verbose=self.verbose,)
1165
+
1166
+ except Exception as e:
1167
+ beauty_string(f'There is a problem loading the weights on file {tmp_path} {e}','section',self.verbose)
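# --- Editor's note: hedged sketch of the save/load round trip defined above; not part of the
# --- packaged file. 'MyModelClass' is a placeholder for the model class used at training time.
ts.save('example_ts')                         # writes example_ts.pkl (weights stay in dirpath)
ts_loaded = TimeSeries('example')
ts_loaded.load(MyModelClass, 'example_ts', load_last=True)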