dsipts 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic. Click here for more details.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1167 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import plotly.express as px
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from typing import List
|
|
5
|
+
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
|
|
6
|
+
from sklearn.preprocessing import *
|
|
7
|
+
from torch.utils.data import DataLoader
|
|
8
|
+
from .utils import extend_time_df,MetricsCallback, MyDataset, ActionEnum,beauty_string
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
|
|
12
|
+
#new version of lightning
|
|
13
|
+
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
14
|
+
import lightning.pytorch as pl
|
|
15
|
+
from ..models.base_v2 import Base
|
|
16
|
+
beauty_string('V2','block',True)
|
|
17
|
+
OLD_PL = False
|
|
18
|
+
except:
|
|
19
|
+
## older version of lightning
|
|
20
|
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
21
|
+
import pytorch_lightning as pl
|
|
22
|
+
from ..models.base import Base
|
|
23
|
+
beauty_string('V1','block',True)
|
|
24
|
+
|
|
25
|
+
OLD_PL = True
|
|
26
|
+
|
|
27
|
+
from typing import Union
|
|
28
|
+
import os
|
|
29
|
+
import torch
|
|
30
|
+
import pickle
|
|
31
|
+
from datetime import datetime
|
|
32
|
+
from ..models.utils import weight_init_zeros,weight_init
|
|
33
|
+
import logging
|
|
34
|
+
from .modifiers import *
|
|
35
|
+
from aim.pytorch_lightning import AimLogger
|
|
36
|
+
import time
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
pd.options.mode.chained_assignment = None
|
|
41
|
+
log = logging.getLogger(__name__)
|
|
42
|
+
log.addHandler(logging.NullHandler())
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Categorical():

    def __init__(self, name: str, frequency: int, duration: "List[int]", classes: int, action: "ActionEnum", level: "List[float]"):
        """Class for generating toy categorical data.

        Args:
            name (str): name of the categorical signal
            frequency (int): frequency of the signal; a class value is emitted only
                when the internal phase counter is a multiple of `frequency`,
                other steps get the placeholder class -1
            duration (List[int]): duration (in emitted steps) of each class
            classes (int): number of classes
            action (str): one between 'additive' or 'multiplicative'
            level (List[float]): intensity of each class
        """
        self.name = name
        self.frequency = frequency
        self.duration = duration
        self.classes = classes
        self.action = action
        self.level = level
        self.validate()

    def validate(self):
        """Validate the configuration; more checks may be added in the future.

        Raises:
            ValueError: if `level` or `duration` do not cover all the classes.

        :meta private:
        """
        if len(self.level) != self.classes:
            raise ValueError("Length must match")
        # generate_signal indexes duration[_class] for every class, so it needs
        # at least one entry per class (extra entries are harmlessly ignored)
        if len(self.duration) < self.classes:
            raise ValueError("duration must provide at least one entry per class")

    def generate_signal(self, length: int) -> None:
        """Generate the response signal.

        Populates `self.signal_array` (numeric contribution per step) and
        `self.classes_array` (class id per step, -1 where no class is active).

        Args:
            length (int): length of the signal

        Raises:
            ValueError: if `action` is neither 'additive' nor 'multiplicative'.
        """
        if self.action == 'multiplicative':
            # neutral element for a product
            signal = np.ones(length)
        elif self.action == 'additive':
            # neutral element for a sum
            signal = np.zeros(length)
        else:
            # FIX: the original silently fell through with `signal` undefined,
            # producing a confusing NameError below instead of a clear error
            raise ValueError(f"Unknown action {self.action}, expected 'additive' or 'multiplicative'")
        classes = []
        _class = 0
        _level = self.level[0]
        _duration = self.duration[0]

        count = 0        # emitted steps of the current class
        count_freq = 0   # phase counter gating emission via `frequency`
        for i in range(length):
            if count_freq % self.frequency == 0:
                signal[i] = _level
                classes.append(_class)
                count += 1
                if count == _duration:
                    # change class (wrapping around) and advance the phase
                    count = 0
                    _class += 1
                    count_freq += _duration
                    _class = _class % self.classes
                    _level = self.level[_class]
                    _duration = self.duration[_class]
            else:
                classes.append(-1)
                count_freq += 1

        self.classes_array = classes
        self.signal_array = signal

    def plot(self) -> None:
        """Plot the series."""
        # FIX: was `self.signal`, an attribute that is never set anywhere;
        # generate_signal stores the values in `self.signal_array`
        tmp = pd.DataFrame({'time': range(len(self.classes_array)),
                            'signal': self.signal_array,
                            'class': self.classes_array})
        fig = px.scatter(tmp, x='time', y='signal', color='class', title=self.name)
        fig.show()
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class TimeSeries():
|
|
126
|
+
|
|
127
|
+
def __init__(self, name: str, stacked: bool = False):
    """Container describing a (possibly multivariate) time series.

    If no real data is available, a toy series can be synthesized through
    helper classes such as ``Categorical`` combined with ``generate_signal``;
    real data is attached with ``load_signal``.

    Args:
        name (str): name of the series
        stacked (bool): if true it is a stacked model

    Usage:
        Build a toy series with a multiplicative weekly feature, an additive
        monthly feature (a 5-month year) and a spotted event every 100 days::

        >>> settimana = Categorical('settimanale',1,[1,1,1,1,1,1,1],7,'multiplicative',[0.9,0.8,0.7,0.6,0.5,0.99,0.99])
        >>> mese = Categorical('mensile',1,[31,28,20,10,33],5,'additive',[10,20,-10,20,0])
        >>> spot = Categorical('spot',100,[7],1,'additive',[10])
        >>> ts = TimeSeries('prova')
        >>> ts.generate_signal(length=5000, categorical_variables=[settimana,mese,spot], noise_mean=1, type=0)
        >>> ts.plot()
    """
    self.name = name
    self.stacked = stacked
    # freshly built objects are untrained, chatty, and ungrouped by default
    self.is_trained = False
    self.verbose = True
    self.group = None
|
|
152
|
+
def __str__(self) -> str:
    # Human-readable summary: length, variable roles and grouping.
    # NOTE(review): assumes load_signal/generate_signal already ran, so that
    # self.dataset and the variable lists exist — raises AttributeError otherwise.
    return f"Timeseries named {self.name} of length {self.dataset.shape[0]}.\n Categorical variable: {self.cat_var},\n Future variables: {self.future_variables},\n Past variables: {self.past_variables},\n Target variables: {self.target_variables} \n With {'no group' if self.group is None else self.group+' as group' }"
|
|
154
|
+
def __repr__(self) -> str:
    # Same summary as __str__ (kept in sync manually).
    # NOTE(review): assumes load_signal/generate_signal already ran, so that
    # self.dataset and the variable lists exist — raises AttributeError otherwise.
    return f"Timeseries named {self.name} of length {self.dataset.shape[0]}.\n Categorical variable: {self.cat_var},\n Future variables: {self.future_variables},\n Past variables: {self.past_variables},\n Target variables: {self.target_variables}\n With {'no group' if self.group is None else self.group+' as group' }"
|
|
156
|
+
|
|
157
|
+
def set_verbose(self,verbose:bool)->None:
    """Enable or disable the informative messages printed by this object."""
    self.verbose = verbose
|
|
159
|
+
def _generate_base(self, length: int, type: int = 0) -> None:
    """Build the deterministic carrier signal used by `generate_signal`.

    Args:
        length (int): length
        type (int, optional): type of the generated timeseries. Defaults to 0.
    """
    if type != 0:
        # only the cosine baseline (type 0) is implemented so far
        beauty_string('Please implement your own method','block',True)
        return
    # slow cosine carrier: ~100-sample-scaled period, amplitude 10
    self.base_signal = 10 * np.cos(np.arange(length) / (2 * np.pi * length / 100))
    self.out_vars = 1
|
|
174
|
+
def generate_signal(self, length: int = 5000, categorical_variables: "List[Categorical]" = [], noise_mean: int = 1, type: int = 0) -> None:
    """Generate a synthetic signal of the given length, noise level and categorical variables.

    The multiplicative series act on the original signal, the additive series
    are accumulated and added at the end. The TS structure will be populated.

    Args:
        length (int, optional): length of the signal. Defaults to 5000.
        categorical_variables (List[Categorical], optional): list of Categorical variables. Defaults to [].
        noise_mean (int, optional): variance of the noise to add at the end. Defaults to 1.
        type (int, optional): type of the timeseries (only type=0 available right now). Defaults to 0.
    """
    dataset = pd.DataFrame({'time': range(length)})
    self._generate_base(length, type)
    signal = self.base_signal.copy()
    # FIX: the original kept a `tot is None` sentinel that was never updated, so
    # each additive variable overwrote the buffer instead of accumulating into it
    # (and `additive = _signal` aliased the Categorical's own signal_array).
    # Accumulate into a fresh zero buffer instead.
    additive = np.zeros(length)
    self.cat_var = []
    for c in categorical_variables:
        c.generate_signal(length)
        dataset[c.name] = c.classes_array
        self.cat_var.append(c.name)
        if c.action == 'multiplicative':
            signal *= c.signal_array
        else:
            additive += c.signal_array
    # additive series are added at the end, per the documented contract
    signal += additive
    dataset['signal'] = signal + noise_mean * np.random.randn(len(signal))
    self.dataset = dataset

    # toy series: the target is also the only past variable, nothing is known in the future
    self.past_variables = ['signal']
    self.future_variables = []
    self.target_variables = ['signal']
    self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def enrich(self, dataset, columns):
    """Add a calendar feature column derived from `dataset.time`, in place.

    Supported names: 'hour', 'dow', 'month', 'minute'. Any other name is left
    untouched; a warning is emitted when it is not already a dataset column.
    """
    extractors = {
        'hour': lambda t: t.dt.hour,
        'dow': lambda t: t.dt.weekday,
        'month': lambda t: t.dt.month,
        'minute': lambda t: t.dt.minute,
    }
    extractor = extractors.get(columns)
    if extractor is not None:
        dataset[columns] = extractor(dataset.time)
    elif columns not in dataset.columns:
        beauty_string(f'I can not automatically enrich column {columns}. Please contact the developers or add it manually to your dataset.','section',True)
|
|
227
|
+
|
|
228
|
+
def load_signal(self, data: pd.DataFrame,
                enrich_cat: "List[str]" = [],
                past_variables: "List[str]" = [],
                future_variables: "List[str]" = [],
                target_variables: "List[str]" = [],
                cat_past_var: "List[str]" = [],
                cat_fut_var: "List[str]" = [],
                check_past: bool = True,
                group: "Union[None,str]" = None,
                check_holes_and_duplicates: bool = True,
                silly_model: bool = False) -> None:
    """This is a crucial point in the data structure. We expect here to have a dataset with time as timestamp.

    There are some checks:

    1- the duplicates will be removed taking the first instance

    2- the frequency will be inferred taking the minimum time distance between samples

    3- the dataset will be filled completing the missing timestamps

    Args:
        data (pd.DataFrame): input dataset; the column indicating the time must be called `time`
        enrich_cat (List[str], optional): it is possible to let this function enrich the dataset, for example adding the standard columns: hour, dow, month and minute. Defaults to [].
        past_variables (List[str], optional): list of column names of past variables not available for future times. Defaults to [].
        future_variables (List[str], optional): list of future variables available for future times. Defaults to [].
        target_variables (List[str], optional): list of the target variables. They will be added to past_variables by default unless `check_past` is false. Defaults to [].
        cat_past_var (List[str], optional): list of the past categorical variables. Defaults to [].
        cat_fut_var (List[str], optional): list of the future categorical variables. Defaults to [].
        check_past (bool, optional): see `target_variables`. Defaults to True.
        group (str or None, optional): if not None the dataset is considered composed of homogeneous timeseries coming from different realizations (e.g. points of sale, cities, locations) and the relative series are not split during the sample generation. Defaults to None.
        check_holes_and_duplicates (bool, optional): if False duplicates or holes will not be checked; the dataloader may not work correctly, disable at your own risk. Defaults to True.
        silly_model (bool, optional): if True, target variables will be added to the pool of the future variables. This can be useful to see if information passes through the decoder part of your model (if any).

    Raises:
        TypeError: if `time` is neither integer-like nor datetime.
    """
    dataset = data.copy()
    dataset.sort_values(by='time', inplace=True)

    if check_holes_and_duplicates:
        beauty_string('I will drop duplicates, I dont like them','section',self.verbose)
        dataset.drop_duplicates(subset=['time'] if group is None else [group,'time'], keep='first', inplace=True, ignore_index=True)

        # infer the sampling step from one series (all groups are assumed homogeneous)
        if group is None:
            differences = dataset.time.diff()[1:]
        else:
            differences = dataset[dataset[group]==dataset[group].unique()[0]].time.diff()[1:]

        if isinstance(dataset.time[0], datetime):
            freq = pd.to_timedelta(differences.min())
        else:
            if int(dataset.time[0]) == dataset.time[0]:  # integer-like time axis
                freq = int(differences.min())
            else:
                raise TypeError("time must be integer or datetime")
        self.freq = freq

        if differences.nunique() > 1:
            beauty_string("There are holes in the dataset i will try to extend the dataframe inserting NAN",'info',self.verbose)
            beauty_string(f'Detected minumum frequency: {freq}','section',self.verbose)
            dataset = extend_time_df(dataset,freq,group).merge(dataset,how='left')
    else:
        beauty_string("I will compute the frequency as minimum of the time difference",'info',self.verbose)
        self.freq = dataset.time.diff()[1:].min()
        # FIX: the original tested isinstance(dataset.time.dtype, datetime), which is
        # always False (a dtype object is never a datetime instance); test a value instead
        if isinstance(dataset.time[0], datetime):
            self.freq = pd.to_timedelta(self.freq)

    assert len(target_variables) > 0, 'Provide at least one column for target'
    assert 'time' in dataset.columns, 'The temporal column must be called time'
    if set(target_variables).intersection(set(past_variables)) != set(target_variables):
        if check_past:
            beauty_string('I will update past column adding all target columns, if you want to avoid this beahviour please use check_pass as false','info',self.verbose)
            past_variables = list(set(past_variables).union(set(target_variables)))

    # FIX: copy the incoming lists so the shared mutable default arguments are never
    # mutated by the appends below (the original grew the defaults across calls)
    self.cat_past_var = list(cat_past_var)
    self.cat_fut_var = list(cat_fut_var)

    self.group = group
    if group is not None:
        if group not in cat_past_var:
            beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
            # FIX: was self.cat_var.append(group) -> AttributeError, since cat_var
            # is only defined further below; the past list is the intended target
            self.cat_past_var.append(group)
        if group not in cat_fut_var:
            beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
            self.cat_fut_var.append(group)

    self.enrich_cat = enrich_cat
    for c in enrich_cat:
        self.cat_past_var = list(set(self.cat_past_var+[c]))
        self.cat_fut_var = list(set(self.cat_fut_var+[c]))
        if c in dataset.columns:
            # FIX: the original message lacked the f-prefix, printing '{c}' literally
            beauty_string(f'Categorical {c} already present, it will be added to categorical variable but not call the enriching function','info',self.verbose)
        else:
            self.enrich(dataset,c)
    self.cat_var = list(set(self.cat_past_var+self.cat_fut_var))  ## all categorical data

    self.dataset = dataset
    # copies so that later in-place growth (silly_model) cannot touch caller lists
    self.past_variables = list(past_variables)
    self.future_variables = list(future_variables)
    self.target_variables = list(target_variables)
    self.out_vars = len(target_variables)
    self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
    if silly_model:
        beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
        self.future_variables += self.target_variables
|
|
335
|
+
|
|
336
|
+
def plot(self):
    """Easy way to control the loaded data (only target variables are drawn).

    Returns:
        plotly.graph_objects._figure.Figure: figure of the target variables
    """
    beauty_string('Plotting only target variables','block',self.verbose)
    # melt to long form, keeping the group column (if any) for faceting
    id_vars = ['time'] if self.group is None else ['time', self.group]
    long_df = self.dataset[id_vars + self.target_variables].melt(id_vars=id_vars)
    if self.group is None:
        fig = px.line(long_df, x='time', y='value', color='variable', title=self.name)
    else:
        fig = px.line(long_df, x='time', y='value', color='variable', title=self.name, facet_row=self.group)
    fig.show()
    return fig
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def create_data_loader(self,data:pd.DataFrame,
                       past_steps:int,
                       future_steps:int,
                       shift:int=0,
                       keep_entire_seq_while_shifting:bool=False,
                       starting_point:Union[None,dict]=None,
                       skip_step:int=1,
                       is_inference:bool=False
                       )->MyDataset:
    """ Create the dataset for the training/inference step

    Args:
        data (pd.DataFrame): input dataset, usually a subset of self.data
        past_steps (int): past context length
        future_steps (int): future lags to predict
        shift (int, optional): if >0 the future input variables will be shifted (categorical and numerical). For example for attention model it is better to start with a known value of y and use it during the process. Defaults to 0.
        keep_entire_seq_while_shifting (bool, optional): if the dataset is shifted, you may want the future data be of length future_step+shift (like informer), default false
        starting_point (Union[None,dict], optional): a dictionary indicating if a sample must be considered. It is checked for the first lag in the future (useful in the case your model has to predict only starting from hour 12). Defaults to None.
        skip_step (int, optional): usually there is a skip of one between two samples but for debugging or training time purposes you can skip some samples. Defaults to 1.
        is_inference (bool, optional): if True the target values in the future window are not required to be finite (they may be unknown at inference time). Defaults to False.
    Returns:
        MyDataset: class that extends torch.utils.data.Dataset (see utils)
        keys of a batch:
            y : the target variable(s)
            x_num_past: the numerical past variables
            x_num_future: the numerical future variables
            x_cat_past: the categorical past variables
            x_cat_future: the categorical future variables
            idx_target: index of target features in the past array
    """
    beauty_string('Creating data loader','block',self.verbose)

    # per-sample accumulators, stacked into arrays at the end
    x_num_past_samples = []
    x_num_future_samples = []
    x_cat_past_samples = []
    x_cat_future_samples = []
    y_samples = []
    t_samples = []
    g_samples = []

    if starting_point is not None:
        # only the first key of starting_point is honored
        kk = list(starting_point.keys())[0]
        # NOTE(review): this message lacks the f-prefix, so {kk} is printed literally;
        # also beauty_string's return value (not a string check) is used as the assert message
        assert kk not in self.cat_var, beauty_string('CAN NOT USE FEATURE {kk} as starting point it may have a different value due to the normalization step, please add a second column with a suitable name','info',True)

    ##overwrite categorical columns
    for c in self.cat_var:
        self.enrich(data,c)

    # '_GROUP_' is a working column: a single dummy group when no grouping is configured
    if self.group is None:
        data['_GROUP_'] = '1'
    else:
        data['_GROUP_'] = data[self.group].values

    # NOTE(review): self.normalize_per_group, self.scaler_num and self.scaler_cat are
    # set elsewhere (presumably when the train/validation split fits the scalers) —
    # calling this before that setup would raise AttributeError; confirm against callers
    if self.normalize_per_group:
        tot = []
        groups = data[self.group].unique()

        # the group column itself is encoded once, globally
        data[self.group] = self.scaler_cat[self.group].transform(data[self.group].values.reshape(-1,1)).flatten()

        for group in groups:
            tmp = data[data['_GROUP_']==group].copy()

            # per-group scalers are keyed as '<column>_<group>'
            for c in self.num_var:
                tmp[c] = self.scaler_num[f'{c}_{group}'].transform(tmp[c].values.reshape(-1,1)).flatten()
            for c in self.cat_var:
                if c!=self.group:
                    tmp[c] = self.scaler_cat[f'{c}_{group}'].transform(tmp[c].values.reshape(-1,1)).flatten()
            tot.append(tmp)
        data = pd.concat(tot,ignore_index=True)
    else:
        # single set of scalers shared by all rows
        for c in self.cat_var:
            data[c] = self.scaler_cat[c].transform(data[c].values.reshape(-1,1)).flatten()
        for c in self.num_var:
            data[c] = self.scaler_num[c].transform(data[c].values.reshape(-1,1)).flatten()

    # positions of the targets inside the past-variable block (consumed by MyDataset)
    idx_target = []
    for c in self.target_variables:
        idx_target.append(self.past_variables.index(c))

    idx_target_future = []

    for c in self.target_variables:
        if c in self.future_variables:
            idx_target_future.append(self.future_variables.index(c))
    if len(idx_target_future)==0:
        idx_target_future = None

    # stacked models reserve extra future room: future_steps*(future_steps-1) lags
    if self.stacked:
        skip_stacked = future_steps*future_steps-future_steps
    else:
        skip_stacked = 0
    for group in data['_GROUP_'].unique():
        tmp = data[data['_GROUP_']==group]
        groups = tmp['_GROUP_'].values
        t = tmp.time.values
        x_num_past = tmp[self.past_variables].values
        if len(self.future_variables)>0:
            x_num_future = tmp[self.future_variables].values
        if len(self.cat_past_var)>0:
            x_past_cat = tmp[self.cat_past_var].values
        if len(self.cat_fut_var)>0:
            x_fut_cat = tmp[self.cat_fut_var].values
        y_target = tmp[self.target_variables].values

        # mask of admissible window anchors (first future lag must match starting_point)
        if starting_point is not None:
            check = tmp[list(starting_point.keys())[0]].values == starting_point[list(starting_point.keys())[0]]
        else:
            check = [True]*len(y_target)

        # slide a (past_steps | future_steps) window over this group's rows
        for i in range(past_steps,tmp.shape[0]-future_steps-skip_stacked,skip_step):
            if check[i]:

                # xx is only a NaN probe: .mean()/.min() over the candidate slices
                # becomes NaN if any value is missing, so the isfinite gate below
                # rejects windows containing holes
                if len(self.future_variables)>0:
                    if keep_entire_seq_while_shifting:
                        xx = x_num_future[i-shift+skip_stacked:i+future_steps+skip_stacked].mean()
                    else:
                        xx = x_num_future[i-shift+skip_stacked:i+future_steps-shift+skip_stacked].mean()
                else:
                    xx = 0.0
                # at inference the future targets may legitimately be NaN, so skip the probe
                if is_inference is False:
                    xx+=y_target[i+skip_stacked:i+future_steps+skip_stacked].min()

                if np.isfinite(x_num_past[i-past_steps:i].min() + xx):

                    x_num_past_samples.append(x_num_past[i-past_steps:i])
                    if len(self.future_variables)>0:
                        if keep_entire_seq_while_shifting:
                            x_num_future_samples.append(x_num_future[i-shift+skip_stacked:i+future_steps+skip_stacked])
                        else:
                            x_num_future_samples.append(x_num_future[i-shift+skip_stacked:i+future_steps-shift+skip_stacked])
                    if len(self.cat_past_var)>0:
                        x_cat_past_samples.append(x_past_cat[i-past_steps:i])
                    if len(self.cat_fut_var)>0:
                        if keep_entire_seq_while_shifting:
                            x_cat_future_samples.append(x_fut_cat[i-shift+skip_stacked:i+future_steps+skip_stacked])
                        else:
                            x_cat_future_samples.append(x_fut_cat[i-shift+skip_stacked:i+future_steps-shift+skip_stacked])

                    y_samples.append(y_target[i+skip_stacked:i+future_steps+skip_stacked])
                    t_samples.append(t[i+skip_stacked:i+future_steps+skip_stacked])
                    g_samples.append(groups[i])

    if len(self.future_variables)>0:
        try:
            x_num_future_samples = np.stack(x_num_future_samples)
        except Exception as e:
            # best-effort warning: an empty list here means no valid window survived
            beauty_string('WARNING x_num_future_samples is empty and it should not','info',True)

    y_samples = np.stack(y_samples)
    t_samples = np.stack(t_samples)
    g_samples = np.stack(g_samples)

    if len(self.cat_past_var)>0:
        x_cat_past_samples = np.stack(x_cat_past_samples).astype(np.int32)
    if len(self.cat_fut_var)>0:
        x_cat_future_samples = np.stack(x_cat_future_samples).astype(np.int32)
    x_num_past_samples = np.stack(x_num_past_samples)
    # stacked models zero out the numerical past block (mod = 0), others keep it
    if self.stacked:
        mod = 0
    else:
        mod = 1.0
    dd = {'y':y_samples.astype(np.float32),

          'x_num_past':(x_num_past_samples*mod).astype(np.float32)}
    if len(self.cat_past_var)>0:
        dd['x_cat_past'] = x_cat_past_samples
    if len(self.cat_fut_var)>0:
        dd['x_cat_future'] = x_cat_future_samples
    if len(self.future_variables)>0:
        dd['x_num_future'] = x_num_future_samples.astype(np.float32)

    return MyDataset(dd,t_samples,g_samples,idx_target,idx_target_future)
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def split_for_train(self,
                    perc_train: Union[float, None] = 0.6,
                    perc_valid: Union[float, None] = 0.2,
                    range_train: Union[List[Union[datetime, str]], None] = None,
                    range_validation: Union[List[Union[datetime, str]], None] = None,
                    range_test: Union[List[Union[datetime, str]], None] = None,
                    past_steps: int = 100,
                    future_steps: int = 20,
                    shift: int = 0,
                    keep_entire_seq_while_shifting: bool = False,
                    starting_point: Union[None, dict] = None,
                    skip_step: int = 1,
                    normalize_per_group: bool = False,
                    check_consecutive: bool = True,
                    scaler: str = 'StandardScaler()'
                    ) -> List[DataLoader]:
    """Split the data and create the train/validation/test datasets.

    Args:
        perc_train (Union[float,None], optional): fraction of the training set. Defaults to 0.6.
        perc_valid (Union[float,None], optional): fraction of the validation set. Defaults to 0.2.
        range_train (Union[List[Union[datetime, str]],None], optional): two elements indicating the start and end of the training set (string date style or datetime). Defaults to None.
        range_validation (Union[List[Union[datetime, str]],None], optional): two elements indicating the start and end of the validation set. Defaults to None.
        range_test (Union[List[Union[datetime, str]],None], optional): two elements indicating the start and end of the test set. Defaults to None.
        past_steps (int, optional): past steps to consider for making the prediction. Defaults to 100.
        future_steps (int, optional): future steps to predict. Defaults to 20.
        shift (int, optional): see `create_data_loader`. Defaults to 0.
        keep_entire_seq_while_shifting (bool, optional): if the dataset is shifted, you may want the future data to be of length future_steps+shift (like informer). Defaults to False.
        starting_point (Union[None, dict], optional): see `create_data_loader`. Defaults to None.
        skip_step (int, optional): see `create_data_loader`. Defaults to 1.
        normalize_per_group (bool, optional): if True and self.group is not None, the variables are scaled with respect to the groups. Defaults to False.
        check_consecutive (bool, optional): if False it skips the check on the consecutive ranges. Defaults to True.
        scaler (str, optional): instance of a sklearn.preprocessing scaler, as a string. Defaults to 'StandardScaler()'.

    Returns:
        List[DataLoader]: three dataloaders used for training or inference (test may be None if empty)
    """
    beauty_string('Splitting for train','block',self.verbose)

    try:
        ls = self.dataset.shape[0]
    except Exception as _:
        beauty_string('Empty dataset','info', True)
        return None, None, None

    if range_train is None:
        if self.group is None:
            # Plain temporal split on the whole (single-series) dataset.
            beauty_string(f'Split temporally using perc_train: {perc_train} and perc_valid:{perc_valid}','section',self.verbose)
            train = self.dataset.iloc[0:int(perc_train*ls)]
            validation = self.dataset.iloc[int(perc_train*ls):int(perc_train*ls+perc_valid*ls)]
            test = self.dataset.iloc[int(perc_train*ls+perc_valid*ls):]
        else:
            # Temporal split applied independently inside each group.
            beauty_string(f'Split temporally using perc_train: {perc_train} and perc_valid:{perc_valid} for each group!','info',self.verbose)
            train = []
            validation = []
            test = []
            ls = self.dataset.groupby(self.group).time.count().reset_index()
            for group in self.dataset[self.group].unique():
                tmp = self.dataset[self.dataset[self.group]==group]
                lt = ls[ls[self.group]==group].time.values[0]
                train.append(tmp[0:int(perc_train*lt)])
                validation.append(tmp[int(perc_train*lt):int(perc_train*lt+perc_valid*lt)])
                test.append(tmp[int(perc_train*lt+perc_valid*lt):])
            train = pd.concat(train,ignore_index=True)
            validation = pd.concat(validation,ignore_index=True)
            test = pd.concat(test,ignore_index=True)
    else:
        if check_consecutive:
            assert range_train[0]<range_train[1]<=range_validation[0]<range_validation[1]<=range_test[0]<range_test[1], beauty_string('The range are not correct','info',True)
        beauty_string('Split temporally using the time intervals provided','section',self.verbose)
        train = self.dataset[self.dataset.time.between(range_train[0],range_train[1])]
        validation = self.dataset[self.dataset.time.between(range_validation[0],range_validation[1])]
        test = self.dataset[self.dataset.time.between(range_test[0],range_test[1])]

    beauty_string('Train categorical and numerical scalers','block',self.verbose)

    # Scalers are fitted only once: a trained object keeps the scalers used at train time.
    if not self.is_trained:
        self.scaler_cat = {}
        self.scaler_num = {}
        if self.group is None or normalize_per_group is False:
            self.normalize_per_group = False
            for c in self.num_var:
                # NOTE(review): eval on a caller-supplied string; only pass trusted scaler expressions.
                self.scaler_num[c] = eval(scaler)
                self.scaler_num[c].fit(train[c].values.reshape(-1,1))
            for c in self.cat_var:
                # Unseen categories at inference time are mapped to a dedicated extra index.
                self.scaler_cat[c] = OrdinalEncoder(dtype=np.int32,handle_unknown= 'use_encoded_value',unknown_value=train[c].nunique())
                self.scaler_cat[c].fit(train[c].values.reshape(-1,1))
        else:
            self.normalize_per_group = True
            # BUG FIX: the original used `train[c].nunique()` here, but `c` is never bound in
            # this branch; the unknown_value must count the categories of the grouping column.
            self.scaler_cat[self.group] = OrdinalEncoder(dtype=np.int32,handle_unknown= 'use_encoded_value',unknown_value=train[self.group].nunique())
            self.scaler_cat[self.group].fit(train[self.group].values.reshape(-1,1))
            for group in train[self.group].unique():
                tmp = train[train[self.group]==group]
                for c in self.num_var:
                    self.scaler_num[f'{c}_{group}'] = eval(scaler)
                    self.scaler_num[f'{c}_{group}'].fit(tmp[c].values.reshape(-1,1))
                for c in self.cat_var:
                    if c!=self.group:
                        self.scaler_cat[f'{c}_{group}'] = OrdinalEncoder(dtype=np.int32,handle_unknown= 'use_encoded_value',unknown_value=train[c].nunique())
                        self.scaler_cat[f'{c}_{group}'].fit(tmp[c].values.reshape(-1,1))

    dl_train = self.create_data_loader(train,past_steps,future_steps,shift,keep_entire_seq_while_shifting,starting_point,skip_step)
    dl_validation = self.create_data_loader(validation,past_steps,future_steps,shift,keep_entire_seq_while_shifting,starting_point,skip_step)
    if test.shape[0]>0:
        dl_test = self.create_data_loader(test,past_steps,future_steps,shift,keep_entire_seq_while_shifting,starting_point,skip_step)
    else:
        dl_test = None
    return dl_train,dl_validation,dl_test
|
|
653
|
+
|
|
654
|
+
def set_model(self,model:Base,config:dict=None,custom_init:bool=False):
    """Attach the model that will be trained.

    Args:
        model (Base): see `models`
        config (dict, optional): usually the configuration used by the model. Defaults to None.
        custom_init (bool, optional): if True a custom initialization paradigm will be used (see weight_init in models/utils.py).
    """
    self.model = model
    # Optionally re-initialize the freshly attached weights.
    if custom_init:
        self.model.apply(weight_init)
    self.config = config
    beauty_string('Setting the model','block',self.verbose)
    beauty_string(model,'',self.verbose)
|
|
671
|
+
|
|
672
|
+
def train_model(self,dirpath:str,
                split_params:dict,
                batch_size:int=100,
                num_workers:int=4,
                max_epochs:int=500,
                auto_lr_find:bool=True,
                gradient_clip_val:Union[float,None]=None,
                gradient_clip_algorithm:str="value",
                devices:Union[str,List[int]]='auto',
                precision:Union[str,int]=32,
                modifier:Union[None,str]=None,
                modifier_params:Union[None,dict]=None,
                seed:int=42
                )-> float:
    """Train the model.

    Args:
        dirpath (str): path where to put all the useful things (checkpoints, losses)
        split_params (dict): see `split_for_train`
        batch_size (int, optional): batch size. Defaults to 100.
        num_workers (int, optional): num_workers for the dataloader. Defaults to 4.
        max_epochs (int, optional): maximum epochs to perform. Defaults to 500.
        auto_lr_find (bool, optional): find initial learning rate, see `pytorch-lightning`. Defaults to True.
        gradient_clip_val (Union[float,None], optional): gradient_clip_val. Defaults to None. See https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html
        gradient_clip_algorithm (str, optional): gradient_clip_algorithm. Defaults to 'value'. See https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html
        devices (Union[str,List[int]], optional): devices to use. Use 'auto' if cpu or the list of gpus to use otherwise. Defaults to 'auto'.
        precision (Union[str,int], optional): precision to use. Usually 32 bit is fine but for larger models you should try 'bf16'. If 'auto' it will use bf16 for GPU and 32 for cpu.
        modifier (Union[str,None], optional): if not None a modifier is applied to the dataloader (see readme for more information).
        modifier_params (Union[dict,None], optional): parameters of the modifier.
        seed (int, optional): seed for reproducibility.

    Returns:
        float: the last validation loss (100 when it cannot be extracted)
    """
    beauty_string('Training the model','block',self.verbose)

    self.split_params = split_params
    self.check_custom = False
    train,validation,test = self.split_for_train(**self.split_params)

    accelerator = 'gpu' if torch.cuda.is_available() else "cpu"
    strategy = "auto"
    if accelerator == 'gpu':
        strategy = "auto"  ##TODO in future investigate on this
        if precision=='auto':
            # bf16 on GPU; some models may not tolerate it, keep it configurable.
            precision = 'bf16'
        torch.set_float32_matmul_precision('medium')
        beauty_string('Setting multiplication precision to medium','info',self.verbose)
    else:
        devices = 'auto'
        if precision=='auto':
            precision = 32
    beauty_string(f'train:{len(train)}, validation:{len(validation)}, test:{len(test) if test is not None else 0}','section',self.verbose)

    # persistent workers only make sense with GPU training and at least one worker
    persistent_workers = (accelerator=='gpu') and (num_workers>0)

    if modifier is not None:
        # NOTE(review): eval on a caller-supplied string; only pass trusted modifier names.
        modifier = eval(modifier)
        modifier = modifier(**modifier_params)
        train, validation = modifier.fit_transform(train=train,val=validation)
        self.modifier = modifier
    else:
        self.modifier = None

    train_dl = DataLoader(train, batch_size = batch_size , shuffle=True,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
    valid_dl = DataLoader(validation, batch_size = batch_size , shuffle=False,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)

    checkpoint_callback = ModelCheckpoint(dirpath=dirpath,
                                          monitor='val_loss',
                                          save_last = True,
                                          every_n_epochs =1,
                                          verbose = self.verbose,
                                          save_top_k = 1,
                                          filename='checkpoint')

    aim_logger = AimLogger(
        experiment=self.name,
        train_metric_prefix='train_',
        val_metric_prefix='val_',
    )

    # https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    n_params = sum(dict((p.data_ptr(), p.numel()) for p in self.model.parameters()).values())

    # https://discuss.pytorch.org/t/finding-model-size/130275
    param_size = 0
    for param in self.model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in self.model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    size_all_mb = (param_size + buffer_size) / 1024**2

    aim_logger.experiment.track(n_params,name='N-parameters')
    aim_logger.experiment.track(size_all_mb,name='dim-model-MB')
    aim_logger.experiment.track(len(train_dl.dataset),name='len-train')
    aim_logger.experiment.track(len(valid_dl.dataset),name='len-valid')
    tmp = self.config.copy()
    tmp['model_name'] = self.model.name
    aim_logger._run['hyperparameters'] = tmp

    mc = MetricsCallback(dirpath)
    ## TODO with 2+ gpus MetricsCallback does not work (one instance per dataparallel worker, info can not be recovered)
    pl.seed_everything(seed, workers=True)
    self.model.max_epochs = max_epochs

    # Resume from a previous checkpoint if one is present in dirpath.
    if os.path.isfile(os.path.join(dirpath,'last.ckpt')):
        weight_exists = True
        beauty_string('I loaded a previous checkpoint','section',self.verbose)
    else:
        weight_exists = False
        beauty_string('I can not load a previous model','section',self.verbose)

    if OLD_PL:
        trainer = pl.Trainer(default_root_dir=dirpath,
                             logger = aim_logger,
                             max_epochs=max_epochs,
                             callbacks=[checkpoint_callback,mc],
                             auto_lr_find=auto_lr_find,
                             accelerator=accelerator,
                             devices=devices,
                             strategy=strategy,
                             enable_progress_bar=False,
                             precision=precision,
                             gradient_clip_val=gradient_clip_val,
                             gradient_clip_algorithm=gradient_clip_algorithm)
    else:
        trainer = pl.Trainer(default_root_dir=dirpath,
                             logger = aim_logger,
                             max_epochs=max_epochs,
                             callbacks=[checkpoint_callback,mc],
                             strategy='auto',
                             devices=devices,
                             enable_progress_bar=False,
                             precision=precision,
                             gradient_clip_val=gradient_clip_val,
                             gradient_clip_algorithm=gradient_clip_algorithm)
    tot_seconds = time.time()

    # LR tuning only for fresh runs (resumed runs keep the stored lr).
    if auto_lr_find and (weight_exists is False):
        if OLD_PL:
            lr_tuner = trainer.tune(self.model,train_dataloaders=train_dl,val_dataloaders = valid_dl)
            # the tuner leaves temporary .lr_find files behind: clean them up
            files = os.listdir(dirpath)
            for f in files:
                if '.lr_find' in f:
                    os.remove(os.path.join(dirpath,f))
            self.model.optim_config['lr'] = lr_tuner['lr_find'].suggestion()
        else:
            from lightning.pytorch.tuner import Tuner
            tuner = Tuner(trainer)
            lr_finder = tuner.lr_find(self.model,train_dataloaders=train_dl,val_dataloaders = valid_dl)
            self.model.optim_config['lr'] = lr_finder.suggestion() ## we are using it as optim key

    if OLD_PL:
        if weight_exists:
            trainer.fit(self.model, train_dl,valid_dl,ckpt_path=os.path.join(dirpath,'last.ckpt'))
        else:
            trainer.fit(self.model, train_dl,valid_dl)
    else:
        if weight_exists:
            trainer.fit(self.model, train_dataloaders = train_dl,val_dataloaders = valid_dl,ckpt_path=os.path.join(dirpath,'last.ckpt'))
        else:
            trainer.fit(self.model, train_dataloaders = train_dl,val_dataloaders = valid_dl)

    self.checkpoint_file_best = checkpoint_callback.best_model_path
    self.checkpoint_file_last = checkpoint_callback.last_model_path
    if self.checkpoint_file_last=='':
        beauty_string('There is a bug on saving last model I will try to fix it','info',self.verbose)
        self.checkpoint_file_last = checkpoint_callback.best_model_path.replace('checkpoint','last')

    self.dirpath = dirpath
    self.losses = mc.metrics

    # Workaround for multi-gpu runs: losses may have been dumped to a csv by the callback.
    files = os.listdir(dirpath)
    for f in files:
        if '__losses__.csv' in f:
            if len(self.losses['val_loss'])>0:
                self.losses = pd.DataFrame(self.losses)
            else:
                self.losses = pd.read_csv(os.path.join(os.path.join(dirpath,f)))
            os.remove(os.path.join(os.path.join(dirpath,f)))
    if isinstance(self.losses,dict):
        self.losses = pd.DataFrame()

    try:
        if OLD_PL:
            self.model = self.model.load_from_checkpoint(self.checkpoint_file_last)
        else:
            self.model = self.model.__class__.load_from_checkpoint(self.checkpoint_file_last)
    except Exception as _:
        beauty_string(f'There is a problem loading the weights on file MAYBE CHANGED HOW WEIGHTS ARE LOADED {self.checkpoint_file_last}','section',self.verbose)

    try:
        val_loss = self.losses.val_loss.values[-1]
    except Exception as _:
        beauty_string('Can not extract the validation loss, maybe it is a persistent model','info',self.verbose)
        val_loss = 100
    self.is_trained = True

    beauty_string('END of the training process','block',self.verbose)

    aim_logger.experiment.track((time.time()-tot_seconds),name='seconds-training')
    aim_logger.experiment.track(val_loss,name='val-loss-end-train')

    return val_loss
|
|
897
|
+
|
|
898
|
+
def inference_on_set(self,batch_size:int=100,
                     num_workers:int=4,
                     split_params:Union[None,dict]=None,set:str='test',
                     rescaling:bool=True,
                     data:Union[None,torch.utils.data.Dataset]=None)->pd.DataFrame:
    """Get the prediction on a particular set (train, validation or test).

    Args:
        batch_size (int, optional): batch size. Defaults to 100.
        num_workers (int, optional): num workers. Defaults to 4.
        split_params (Union[None,dict], optional): if not None the splitting procedure will use the given parameters, otherwise the same configuration used in train. Defaults to None.
        set (str, optional): 'train', 'validation', 'test' or 'custom'. Defaults to 'test'.
            NOTE: the parameter name shadows the builtin `set`; kept for backward compatibility.
        rescaling (bool, optional): if True the output will be rescaled to the initial values. Defaults to True.
        data (None or torch Dataset, optional): if not None the inference is performed on the given data. In the case of custom data please call `inference` because it will normalize the data for you!

    Returns:
        pd.DataFrame: the predicted values in a pandas format
    """
    beauty_string('Inference on a set (train, validation o test)','block',self.verbose)

    if data is None:
        if split_params is None:
            beauty_string(f'splitting using train parameters {self.split_params}','section',self.verbose)
            train,validation,test = self.split_for_train(**self.split_params)
        else:
            train,validation,test = self.split_for_train(**split_params)

    if set=='test':
        if self.modifier is not None:
            test = self.modifier.transform(test)
        dl = DataLoader(test, batch_size = batch_size , shuffle=False,drop_last=False,num_workers=num_workers)
    elif set=='validation':
        if self.modifier is not None:
            validation = self.modifier.transform(validation)
        dl = DataLoader(validation, batch_size = batch_size , shuffle=False,drop_last=False,num_workers=num_workers)
    elif set=='train':
        if self.modifier is not None:
            train = self.modifier.transform(train)
        dl = DataLoader(train, batch_size = batch_size , shuffle=False,drop_last=False,num_workers=num_workers)
    elif set=='custom':
        if self.check_custom:
            pass
        else:
            beauty_string('If you are here something went wrong, please report it','section',self.verbose)
        if self.modifier is not None:
            data = self.modifier.transform(data)
        dl = DataLoader(data, batch_size = batch_size , shuffle=False,drop_last=False,num_workers=num_workers)
    else:
        # NOTE(review): with an invalid `set` the code below fails on an unbound `dl`
        beauty_string('Select one of train, test, or validation set','section',self.verbose)

    self.model.eval()

    res = []
    real = []
    self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    beauty_string(f'Device used: {self.model.device}','info',self.verbose)

    for batch in dl:
        res.append(self.model.inference(batch).cpu().detach().numpy())
        real.append(batch['y'].cpu().detach().numpy())

    res = np.vstack(res)
    real = np.vstack(real)
    time = dl.dataset.t
    groups = dl.dataset.groups
    if self.modifier is not None:
        res,real = self.modifier.inverse_transform(res,real)

    ## res is BxLxCx3 (quantile case) or BxLxCx1
    if rescaling:
        beauty_string('Scaling back','info',self.verbose)
        if self.normalize_per_group is False:
            for i, c in enumerate(self.target_variables):
                real[:,:,i] = self.scaler_num[c].inverse_transform(real[:,:,i].reshape(-1,1)).reshape(-1,real.shape[1])
                for j in range(res.shape[3]):
                    res[:,:,i,j] = self.scaler_num[c].inverse_transform(res[:,:,i,j].reshape(-1,1)).reshape(-1,res.shape[1])
        else:
            # invert the per-group scalers sample by sample
            for group in np.unique(groups):
                idx = np.where(groups==group)[0]
                for i, c in enumerate(self.target_variables):
                    real[idx,:,i] = self.scaler_num[f'{c}_{group}'].inverse_transform(real[idx,:,i].reshape(-1,1)).reshape(-1,real.shape[1])
                    for j in range(res.shape[3]):
                        res[idx,:,i,j] = self.scaler_num[f'{c}_{group}'].inverse_transform(res[idx,:,i,j].reshape(-1,1)).reshape(-1,res.shape[1])

    if self.model.use_quantiles:
        ## BxLxCx3: low/median/high quantiles
        time = pd.DataFrame(time,columns=[i+1 for i in range(res.shape[1])])
        if self.group is not None:
            time[self.group] = groups
            # BUG FIX: the original hardcoded id_vars=['region']; must use the configured group column
            time = time.melt(id_vars=[self.group])
        else:
            time = time.melt()
        time.rename(columns={'value':'time','variable':'lag'},inplace=True)

        tot = [time]
        for i, c in enumerate(self.target_variables):
            tot.append(pd.DataFrame(real[:,:,i],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c}).drop(columns=['variable']))
            tot.append(pd.DataFrame(res[:,:,i,0],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c+'_low'}).drop(columns=['variable']))
            tot.append(pd.DataFrame(res[:,:,i,1],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c+'_median'}).drop(columns=['variable']))
            tot.append(pd.DataFrame(res[:,:,i,2],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c+'_high'}).drop(columns=['variable']))
        res = pd.concat(tot,axis=1)

    ## BxLxCx1: point prediction
    else:
        time = pd.DataFrame(time,columns=[i+1 for i in range(res.shape[1])])
        if self.group is not None:
            time[self.group] = groups
            # BUG FIX: the original hardcoded id_vars=['region']; must use the configured group column
            time = time.melt(id_vars=[self.group])
        else:
            time = time.melt()
        time.rename(columns={'value':'time','variable':'lag'},inplace=True)

        tot = [time]
        for i, c in enumerate(self.target_variables):
            tot.append(pd.DataFrame(real[:,:,i],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c}).drop(columns=['variable']))
            tot.append(pd.DataFrame(res[:,:,i,0],columns=[i+1 for i in range(res.shape[1])]).melt().rename(columns={'value':c+'_pred'}).drop(columns=['variable']))
        res = pd.concat(tot,axis=1)

    # the time at which the prediction was emitted (lag steps before the predicted time)
    res['prediction_time'] = res.apply(lambda x: x.time-self.freq*x.lag, axis=1)
    return res
|
|
1026
|
+
def inference(self,batch_size:int=100,
              num_workers:int=4,
              split_params:Union[None,dict]=None,
              rescaling:bool=True,
              data:pd.DataFrame=None,
              steps_in_future:int=0,
              check_holes_and_duplicates:bool=True,
              is_inference:bool=False)->pd.DataFrame: ##TODO PUSH THIS ON PTF!
    """Similar to `inference_on_set`.

    The only change is split_params that must contain these keys (using the defaults can be sufficient):
    'past_steps','future_steps','shift','keep_entire_seq_while_shifting','starting_point'.
    skip_step is set to 1 for convenience (generally you want all the predictions).
    You can set split_params to None and use the standard parameters (at your own risk).

    Args:
        batch_size (int, optional): see inference_on_set. Defaults to 100.
        num_workers (int, optional): see inference_on_set. Defaults to 4.
        split_params (Union[None,dict], optional): see inference_on_set. Defaults to None.
        rescaling (bool, optional): see inference_on_set. Defaults to True.
        data (pd.DataFrame, optional): starting dataset. Defaults to None.
        steps_in_future (int, optional): if >0 the dataset is extended in order to make predictions in the future. Defaults to 0.
        check_holes_and_duplicates (bool, optional): if False the routine does not check for holes or for duplicates, set to False for stacked models. Defaults to True.
        is_inference (bool, optional): forwarded to `create_data_loader`.

    Returns:
        pd.DataFrame: predicted values
    """
    beauty_string('Inference on a custom dataset','block',self.verbose)
    self.check_custom = True ##this is a check for the dataset loading
    ## enlarge the dataset in order to have all the rows needed
    if check_holes_and_duplicates:
        if self.group is None:
            freq = self.freq #TODO port it into PTF
            beauty_string(f'Detected minumum frequency: {freq}','section',self.verbose)
            ## TODO work on this for consistency
            empty = pd.DataFrame({'time':pd.date_range(data.time.min(),data.time.max()+freq*(steps_in_future+self.split_params['past_steps']+self.split_params['future_steps']),freq=freq)})
        else:
            # BUG FIX: the original had the closing bracket misplaced
            # (data[data[self.group==...]]) which indexes by a boolean scalar.
            freq = pd.to_timedelta(np.diff(data[data[self.group]==data[self.group].unique()[0]].time).min())
            beauty_string(f'Detected minumum frequency: {freq} supposing constant frequence inside the groups','section',self.verbose)
            # NOTE(review): reset_index added so the per-group min/max can be filtered
            # by column below (groupby(...).time.min() alone returns a Series) — confirm.
            _min = data.groupby(self.group).time.min().reset_index()
            _max = data.groupby(self.group).time.max().reset_index()
            empty = []
            for c in data[self.group].unique():
                empty.append(pd.DataFrame({self.group:c,'time':pd.date_range(_min.time[_min[self.group]==c].values[0],_max.time[_max[self.group]==c].values[0]+freq*(steps_in_future+self.split_params['past_steps']+self.split_params['future_steps']),freq=freq)}))
            empty = pd.concat(empty,ignore_index=True)
        dataset = empty.merge(data,how='left')
        #TODO port it into PTF
        for c in self.cat_var:
            self.enrich(dataset, c)
    else:
        dataset = data.copy()

    if split_params is None:
        # reuse only the structural keys from the training split configuration
        split_params = {}
        for c in self.split_params.keys():
            if c in ['past_steps','future_steps','shift','keep_entire_seq_while_shifting','starting_point']:
                split_params[c] = self.split_params[c]
        split_params['skip_step']=1
        data = self.create_data_loader(dataset,**split_params,is_inference=is_inference)
    else:
        data = self.create_data_loader(data,**split_params,is_inference=is_inference)

    res = self.inference_on_set(batch_size=batch_size,num_workers=num_workers,split_params=None,set='custom',rescaling=rescaling,data=data)
    self.check_custom = False
    return res
|
|
1096
|
+
|
|
1097
|
+
def save(self, filename:str)->None:
    """Save the timeseries object (everything except the model) as a pickle.

    The sibling `load` method reads back `filename + '.pkl'`, so the file is
    written under that exact name.

    Args:
        filename (str): name of the file (without the .pkl extension)
    """
    beauty_string('Saving','block',self.verbose)
    # BUG FIX: the filename parameter was ignored (a literal placeholder was used);
    # write to '<filename>.pkl' so that `load(filename)` can find it.
    with open(f'{filename}.pkl','wb') as f:
        params = self.__dict__.copy()
        # the model is not picklable / is restored separately from its checkpoint
        for k in ['model']:
            if k in params.keys():
                _ = params.pop(k)
        pickle.dump(params,f)
|
|
1110
|
+
|
|
1111
|
+
|
|
1112
|
+
def load(self, model: Base, filename: str, load_last: bool = True,
         dirpath: Union[str, None] = None,
         weight_path: Union[str, None] = None) -> None:
    """Load a saved model.

    Restores the pickled state produced by ``save``, re-instantiates the
    model from the stored configs, then loads the checkpoint weights.

    Args:
        model (Base): class of the model to load (it will be initiated by pytorch-lightning)
        filename (str): filename of the saved model (without the ``.pkl`` extension)
        load_last (bool, optional): if true the last checkpoint will be loaded otherwise the best (in the validation set). Defaults to True.
        dirpath (Union[str,None], optional): if None we assume that the model is loaded from the same pc where it has been trained, otherwise we can pass the dirpath where all the stuff has been saved. Defaults to None.
        weight_path (Union[str, None], optional): if None the standard path will be used. Defaults to None.
    """
    beauty_string('Loading', 'block', self.verbose)
    self.modifier = None
    self.check_custom = False
    self.is_trained = True
    # restore every attribute saved by `save` (this also restores
    # self.dirpath, self.checkpoint_file_last/best and self.config)
    with open(filename + '.pkl', 'rb') as f:
        params = pickle.load(f)
        for p in params:
            setattr(self, p, params[p])
    # `verbose` is passed explicitly below; avoid a duplicated kwarg
    if 'verbose' in self.config['model_configs'].keys():
        self.config['model_configs'].pop('verbose')
    self.model = model(**self.config['model_configs'],
                       optim_config=self.config['optim_config'],
                       scheduler_config=self.config['scheduler_config'],
                       verbose=self.verbose)

    if weight_path is not None:
        tmp_path = weight_path
    else:
        # BUG FIX: the original preferred self.dirpath (restored from the
        # pickle, i.e. a path valid on the *training* machine) over the
        # explicit `dirpath` argument, so passing `dirpath` could never
        # take effect. Per the documented contract, an explicit argument
        # overrides the stored path.
        if dirpath is not None:
            directory = dirpath
        else:
            directory = self.dirpath
        # checkpoint_file_* may be missing from older pickles; the
        # try/except falls back to whichever checkpoint is available
        if load_last:
            try:
                tmp_path = os.path.join(directory, self.checkpoint_file_last.split('/')[-1])
            except Exception as _:
                beauty_string('checkpoint_file_last not defined try to load best', 'section', self.verbose)
                tmp_path = os.path.join(directory, self.checkpoint_file_best.split('/')[-1])
        else:
            try:
                tmp_path = os.path.join(directory, self.checkpoint_file_best.split('/')[-1])
            except Exception as _:
                # BUG FIX: message said 'try to load best' but this branch
                # falls back to the *last* checkpoint
                beauty_string('checkpoint_file_best not defined try to load last', 'section', self.verbose)
                tmp_path = os.path.join(directory, self.checkpoint_file_last.split('/')[-1])
    try:
        # with torch.serialization.add_safe_globals([ListConfig]):
        # OLD_PL: older pytorch-lightning exposes load_from_checkpoint on
        # the instance; newer versions require the class
        if OLD_PL:
            self.model = self.model.load_from_checkpoint(tmp_path, verbose=self.verbose,)
        else:
            self.model = self.model.__class__.load_from_checkpoint(tmp_path, verbose=self.verbose,)
    except Exception as e:
        beauty_string(f'There is a problem loading the weights on file {tmp_path} {e}', 'section', self.verbose)