dsipts 1.1.10__py3-none-any.whl → 1.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsipts/data_structure/data_structure.py +57 -20
- dsipts/models/Autoformer.py +2 -1
- dsipts/models/CrossFormer.py +2 -1
- dsipts/models/D3VAE.py +2 -1
- dsipts/models/Diffusion.py +3 -0
- dsipts/models/DilatedConv.py +2 -1
- dsipts/models/DilatedConvED.py +2 -1
- dsipts/models/Duet.py +2 -1
- dsipts/models/ITransformer.py +5 -8
- dsipts/models/Informer.py +2 -1
- dsipts/models/LinearTS.py +2 -1
- dsipts/models/PatchTST.py +3 -0
- dsipts/models/RNN.py +2 -1
- dsipts/models/Samformer.py +3 -1
- dsipts/models/Simple.py +3 -1
- dsipts/models/TFT.py +4 -0
- dsipts/models/TIDE.py +4 -1
- dsipts/models/TTM.py +72 -166
- dsipts/models/TimeXER.py +3 -1
- dsipts/models/base.py +47 -35
- dsipts/models/base_v2.py +53 -38
- dsipts/models/duet/layers.py +6 -2
- {dsipts-1.1.10.dist-info → dsipts-1.1.12.dist-info}/METADATA +1 -1
- {dsipts-1.1.10.dist-info → dsipts-1.1.12.dist-info}/RECORD +26 -26
- {dsipts-1.1.10.dist-info → dsipts-1.1.12.dist-info}/WHEEL +0 -0
- {dsipts-1.1.10.dist-info → dsipts-1.1.12.dist-info}/top_level.txt +0 -0
dsipts/data_structure/data_structure.py
CHANGED

```diff
@@ -35,7 +35,18 @@ from .modifiers import *
 from aim.pytorch_lightning import AimLogger
 import time
 
-
+class DummyScaler():
+    def __init__(self):
+        pass
+    def fit(self,x):
+        pass
+    def transform(self,x):
+        return x
+    def inverse_transform(self,x):
+        return x
+    def fit_transform(self,x):
+        return x
+
 
 pd.options.mode.chained_assignment = None
 log = logging.getLogger(__name__)
```
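
The new `DummyScaler` is an identity stand-in that satisfies the sklearn scaler protocol (`fit`/`transform`/`inverse_transform`/`fit_transform`), so downstream scaling code can run unchanged when no scaling is wanted. A minimal usage sketch; the `use_scaling` toggle and the `StandardScaler` comparison are illustrative, not part of the package:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler  # only for the comparison below

class DummyScaler:
    """Identity scaler with the sklearn scaler interface."""
    def fit(self, x):
        pass
    def transform(self, x):
        return x
    def inverse_transform(self, x):
        return x
    def fit_transform(self, x):
        return x

use_scaling = False  # illustrative toggle
scaler = StandardScaler() if use_scaling else DummyScaler()
x = np.array([[1.0], [2.0], [3.0]])
x_scaled = scaler.fit_transform(x)           # identity when scaling is off
x_back = scaler.inverse_transform(x_scaled)  # round-trips unchanged
```
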
```diff
@@ -210,20 +221,23 @@ class TimeSeries():
         self.future_variables = []
         self.target_variables = ['signal']
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
-
+        self.num_var = list(np.sort(self.num_var))
 
     def enrich(self,dataset,columns):
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            if columns =='hour':
+                dataset[columns] = dataset.time.dt.hour
+            elif columns=='dow':
+                dataset[columns] = dataset.time.dt.weekday
+            elif columns=='month':
+                dataset[columns] = dataset.time.dt.month
+            elif columns=='minute':
+                dataset[columns] = dataset.time.dt.minute
+            else:
+                if columns not in dataset.columns:
+                    beauty_string(f'I can not automatically enrich column {columns}. Please contact the developers or add it manually to your dataset.','section',True)
+        except:
+            beauty_string(f'I can not automatically enrich column {columns}. Probably not a temporal index.','section',True)
 
     def load_signal(self,data:pd.DataFrame,
                     enrich_cat:List[str] = [],
```
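
The rewritten `enrich` derives calendar columns from the dataset's `time` column with pandas datetime accessors and degrades to a warning (via `beauty_string`) when a column cannot be derived. A standalone sketch of the same derivations, assuming a DataFrame with a datetime64 `time` column:

```python
import pandas as pd

df = pd.DataFrame({"time": pd.date_range("2024-01-01", periods=96, freq="h"),
                   "signal": range(96)})

# the same accessor mapping enrich() now applies
df["hour"] = df.time.dt.hour      # 0-23
df["dow"] = df.time.dt.weekday    # 0 = Monday ... 6 = Sunday
df["month"] = df.time.dt.month    # 1-12
df["minute"] = df.time.dt.minute  # 0-59
```
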
```diff
@@ -300,7 +314,7 @@ class TimeSeries():
         if check_past:
             beauty_string('I will update past column adding all target columns, if you want to avoid this beahviour please use check_pass as false','info',self.verbose)
             past_variables = list(set(past_variables).union(set(target_variables)))
-
+        past_variables = list(np.sort(past_variables))
         self.cat_past_var = cat_past_var
         self.cat_fut_var = cat_fut_var
 
@@ -321,14 +335,18 @@ class TimeSeries():
                 beauty_string('Categorical {c} already present, it will be added to categorical variable but not call the enriching function','info',self.verbose)
             else:
                 self.enrich(dataset,c)
+        self.cat_past_var = list(np.sort(self.cat_past_var))
+        self.cat_fut_var = list(np.sort(self.cat_fut_var))
+
         self.cat_var = list(set(self.cat_past_var+self.cat_fut_var)) ## all categorical data
-
+        self.cat_var = list(np.sort(self.cat_var))
         self.dataset = dataset
         self.past_variables = past_variables
         self.future_variables = future_variables
         self.target_variables = target_variables
         self.out_vars = len(target_variables)
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
+        self.num_var = list(np.sort(self.num_var))
         if silly_model:
             beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
             self.future_variables+=self.target_variables
@@ -665,7 +683,8 @@ class TimeSeries():
         #self.model.apply(weight_init_zeros)
 
         self.config = config
-
+
+
         beauty_string('Setting the model','block',self.verbose)
         beauty_string(model,'',self.verbose)
```
```diff
@@ -790,8 +809,17 @@ class TimeSeries():
             weight_exists = False
             beauty_string('I can not load a previous model','section',self.verbose)
 
+        self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
+        if self.model.can_be_compiled():
+            try:
+                self.model = torch.compile(self.model)
+                beauty_string('Model COMPILED','block',self.verbose)
+
+            except:
+                beauty_string('Can not compile the model','block',self.verbose)
+        else:
+            beauty_string('Model can not still be compiled, be patient','block',self.verbose)
 
-
 
         if OLD_PL:
             trainer = pl.Trainer(default_root_dir=dirpath,
```
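
Training now moves the model to the available device and then attempts `torch.compile`, gated by the new per-model `can_be_compiled()` flag and wrapped in a broad try/except, since compilation can still fail on some graph structures. A reduced sketch of that gating outside dsipts; `TinyModel` is a stand-in, and only the control flow mirrors the diff:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 1)
    def can_be_compiled(self) -> bool:
        return True  # models with dynamo-hostile ops report False
    def forward(self, x):
        return self.net(x)

model = TinyModel()
model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
if model.can_be_compiled():
    try:
        model = torch.compile(model)  # torch >= 2.0
        print("Model COMPILED")
    except Exception:
        print("Can not compile the model")  # keep the eager model and continue
else:
    print("Model can not be compiled yet")
```
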
```diff
@@ -873,10 +901,19 @@ class TimeSeries():
         self.losses = pd.DataFrame()
 
         try:
+
             if OLD_PL:
-                self.model
+                if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
+                    self.model = self.model._orig_mod
+                    self.model.load_from_checkpoint(self.checkpoint_file_last)
+                else:
+                    self.model = self.model.load_from_checkpoint(self.checkpoint_file_last)
             else:
-                self.model
+                if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
+                    mm = self.model._orig_mod
+                    self.model = mm.__class__.load_from_checkpoint(self.checkpoint_file_last)
+                else:
+                    self.model = self.model.__class__.load_from_checkpoint(self.checkpoint_file_last)
 
         except Exception as _:
             beauty_string(f'There is a problem loading the weights on file MAYBE CHANGED HOW WEIGHTS ARE LOADED {self.checkpoint_file_last}','section',self.verbose)
```
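
Once compiled, the model object is a `torch._dynamo.eval_frame.OptimizedModule` wrapper, so `load_from_checkpoint` must be called on (the class of) the original module, reachable via `._orig_mod`. A sketch of the unwrap step the loader now performs; the checkpoint path is hypothetical:

```python
import torch

def unwrap_compiled(model: torch.nn.Module) -> torch.nn.Module:
    """Return the underlying module if `model` was wrapped by torch.compile."""
    if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
        return model._orig_mod
    return model

# e.g. before reloading Lightning weights (path is hypothetical):
# model = unwrap_compiled(model).__class__.load_from_checkpoint("checkpoint_last.ckpt")
```
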
```diff
@@ -1164,6 +1201,6 @@ class TimeSeries():
                 self.model = self.model.load_from_checkpoint(tmp_path,verbose=self.verbose,)
             else:
                 self.model = self.model.__class__.load_from_checkpoint(tmp_path,verbose=self.verbose,)
-
+            self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
         except Exception as e:
             beauty_string(f'There is a problem loading the weights on file {tmp_path} {e}','section',self.verbose)
```

dsipts/models/Autoformer.py
CHANGED
```diff
@@ -148,7 +148,8 @@ class Autoformer(Base):
             projection=nn.Linear(d_model, self.out_channels*self.mul, bias=True)
         )
         self.projection = nn.Linear(self.past_channels,self.out_channels*self.mul )
-
+    def can_be_compiled(self):
+        return True
     def forward(self, batch):
 
 
```
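
The same two-line override recurs in most of the model files below: models whose graphs trace cleanly return `True` (Autoformer, DilatedConvED, ITransformer, LinearTS, PatchTST, Simple, TTM), while models with dynamo-problematic ops return `False` (Diffusion, Duet, TIDE). A sketch of the pattern; whether `Base` itself supplies a default is an assumption, since the diff only shows the per-model overrides:

```python
import torch.nn as nn

class Base(nn.Module):
    """Stand-in for dsipts' Base; a conservative default is assumed here."""
    def can_be_compiled(self) -> bool:
        return False

class Autoformer(Base):
    def can_be_compiled(self) -> bool:
        return True   # graph traces cleanly under torch.compile

class Duet(Base):
    def can_be_compiled(self) -> bool:
        return False  # sparse expert dispatch still trips dynamo
```
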
dsipts/models/CrossFormer.py
CHANGED
dsipts/models/D3VAE.py
CHANGED
dsipts/models/Diffusion.py
CHANGED
```diff
@@ -425,6 +425,9 @@ class Diffusion(Base):
         loss = self.compute_loss(batch,out)
         return loss
 
+    def can_be_compiled(self):
+        return False
+
     # function to concat embedded categorical variables
     def cat_categorical_vars(self, batch:dict):
         """Extracting categorical context about past and future
```
dsipts/models/DilatedConv.py
CHANGED
dsipts/models/DilatedConvED.py
CHANGED
```diff
@@ -228,7 +228,8 @@ class DilatedConvED(Base):
             nn.BatchNorm1d(hidden_RNN) if use_bn else nn.Dropout(dropout_rate) ,
             Permute() if use_bn else nn.Identity() ,
             nn.Linear(hidden_RNN ,self.mul))
-
+    def can_be_compiled(self):
+        return True
 
 
     def forward(self, batch):
```
dsipts/models/Duet.py
CHANGED
```diff
@@ -136,7 +136,8 @@ class Duet(Base):
             activation(),
             nn.Linear(dim*2,self.out_channels*self.mul ))
 
-
+    def can_be_compiled(self):
+        return False
     def forward(self, batch:dict)-> float:
         # x: [Batch, Input length, Channel]
         x_enc = batch['x_num_past'].to(self.device)
```
dsipts/models/ITransformer.py
CHANGED
```diff
@@ -8,6 +8,8 @@ import numpy as np
 from .itransformer.Transformer_EncDec import Encoder, EncoderLayer
 from .itransformer.SelfAttention_Family import FullAttention, AttentionLayer
 from .itransformer.Embed import DataEmbedding_inverted
+from ..data_structure.utils import beauty_string
+from .utils import get_scope,get_activation,Embedding_cat_variables
 
 try:
     import lightning.pytorch as pl
@@ -17,12 +19,6 @@ except:
     import pytorch_lightning as pl
     OLD_PL = True
 from .base import Base
-from .utils import QuantileLossMO,Permute, get_activation
-
-from typing import List, Union
-from ..data_structure.utils import beauty_string
-from .utils import get_scope
-from .utils import Embedding_cat_variables
 
 
 
@@ -34,8 +30,6 @@ class ITransformer(Base):
     description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
 
     def __init__(self,
-
-
                  # specific params
                  hidden_size:int,
                  d_model: int,
@@ -107,6 +101,9 @@ class ITransformer(Base):
         )
         self.projector = nn.Linear(d_model, self.future_steps*self.mul, bias=True)
 
+    def can_be_compiled(self):
+        return True
+
     def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
         if self.use_norm:
             # Normalization from Non-stationary Transformer
```
dsipts/models/Informer.py
CHANGED
dsipts/models/LinearTS.py
CHANGED
```diff
@@ -143,7 +143,8 @@ class LinearTS(Base):
             activation(),
             nn.BatchNorm1d(hidden_size//8) if use_bn else nn.Dropout(dropout_rate) ,
             nn.Linear(hidden_size//8,self.future_steps*self.mul)))
-
+    def can_be_compiled(self):
+        return True
     def forward(self, batch):
 
         x = batch['x_num_past'].to(self.device)
```
dsipts/models/PatchTST.py
CHANGED
```diff
@@ -133,6 +133,9 @@ class PatchTST(Base):
 
         #self.final_linear = nn.Sequential(nn.Linear(past_channels,past_channels//2),activation(),nn.Dropout(dropout_rate), nn.Linear(past_channels//2,out_channels) )
 
+    def can_be_compiled(self):
+        return True
+
     def forward(self, batch): # x: [Batch, Input length, Channel]
 
 
```
dsipts/models/RNN.py
CHANGED
dsipts/models/Samformer.py
CHANGED
dsipts/models/Simple.py
CHANGED
```diff
@@ -67,7 +67,9 @@ class Simple(Base):
         self.linear = (nn.Sequential(nn.Linear(emb_past_out_channel*self.past_steps+emb_fut_out_channel*self.future_steps+self.past_steps*self.past_channels+self.future_channels*self.future_steps,hidden_size),
                                      activation(),nn.Dropout(dropout_rate),
                                      nn.Linear(hidden_size,self.out_channels*self.future_steps*self.mul)))
-
+    def can_be_compiled(self):
+        return True
+
     def forward(self, batch):
 
         x = batch['x_num_past'].to(self.device)
```
dsipts/models/TFT.py
CHANGED
dsipts/models/TIDE.py
CHANGED
```diff
@@ -106,7 +106,10 @@ class TIDE(Base):
 
         # linear for Y lookback
         self.linear_target = nn.Linear(self.past_steps*self.out_channels, self.future_steps*self.out_channels*self.mul)
-
+
+    def can_be_compiled(self):
+        return False
+
 
     def forward(self, batch:dict)-> float:
         """training process of the diffusion network
```
dsipts/models/TTM.py
CHANGED
```diff
@@ -12,240 +12,146 @@ except:
 from .base import Base
 
 
-
-from typing import List,Union
-
-from .utils import QuantileLossMO
+from .ttm.utils import get_model, get_frequency_token, count_parameters, DEFAULT_FREQUENCY_MAPPING
 from ..data_structure.utils import beauty_string
-from .
-
+from .utils import get_scope
 
 class TTM(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
     def __init__(self,
                  model_path:str,
-                 past_steps:int,
-                 future_steps:int,
-                 freq_prefix_tuning:bool,
-                 freq:str,
                  prefer_l1_loss:bool, # exog: set true to use l1 loss
                  prefer_longer_context:bool,
-                 loss_type:str,
-                 num_input_channels,
                  prediction_channel_indices,
-
+                 exogenous_channel_indices_cont,
+                 exogenous_channel_indices_cat,
                  decoder_mode,
+                 freq,
+                 freq_prefix_tuning,
                  fcm_context_length,
                  fcm_use_mixer,
                  fcm_mix_layers,
                  fcm_prepend_past,
                  enable_forecast_channel_mixing,
-                 out_channels:int,
-                 embs:List[int],
-                 remove_last = False,
-                 optim:Union[str,None]=None,
-                 optim_config:dict=None,
-                 scheduler_config:dict=None,
-                 verbose = False,
-                 use_quantiles=False,
-                 persistence_weight:float=0.0,
-                 quantiles:List[int]=[],
                  **kwargs)->None:
-
-
-        Args:
-            model_path (str): _description_
-            past_steps (int): _description_
-            future_steps (int): _description_
-            freq_prefix_tuning (bool): _description_
-            freq (str): _description_
-            prefer_l1_loss (bool): _description_
-            loss_type (str): _description_
-            num_input_channels (_type_): _description_
-            prediction_channel_indices (_type_): _description_
-            exogenous_channel_indices (_type_): _description_
-            decoder_mode (_type_): _description_
-            fcm_context_length (_type_): _description_
-            fcm_use_mixer (_type_): _description_
-            fcm_mix_layers (_type_): _description_
-            fcm_prepend_past (_type_): _description_
-            enable_forecast_channel_mixing (_type_): _description_
-            out_channels (int): _description_
-            embs (List[int]): _description_
-            remove_last (bool, optional): _description_. Defaults to False.
-            optim (Union[str,None], optional): _description_. Defaults to None.
-            optim_config (dict, optional): _description_. Defaults to None.
-            scheduler_config (dict, optional): _description_. Defaults to None.
-            verbose (bool, optional): _description_. Defaults to False.
-            use_quantiles (bool, optional): _description_. Defaults to False.
-            persistence_weight (float, optional): _description_. Defaults to 0.0.
-            quantiles (List[int], optional): _description_. Defaults to [].
-        """
-        super(TTM, self).__init__(verbose)
+
+        super().__init__(**kwargs)
         self.save_hyperparameters(logger=False)
-
-
-
-        self.
-
-
-
-        self.remove_last = remove_last
-        self.embs = embs
-        self.freq = freq
-        self.extend_variables = False
-
-        # NOTE: For Hydra
-        prediction_channel_indices = list(prediction_channel_indices)
-        exogenous_channel_indices = list(exogenous_channel_indices)
-
-        if len(quantiles)>0:
-            assert len(quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
-            self.use_quantiles = True
-            self.mul = len(quantiles)
-            self.loss = QuantileLossMO(quantiles)
-            self.extend_variables = True
-            if out_channels * 3 != len(prediction_channel_indices):
-                prediction_channel_indices, exogenous_channel_indices, num_input_channels = self.__add_quantile_features(prediction_channel_indices,
-                                                                                                                         exogenous_channel_indices,
-                                                                                                                         out_channels)
+
+
+
+        self.index_fut = list(exogenous_channel_indices_cont)
+
+        if len(exogenous_channel_indices_cat)>0:
+            self.index_fut_cat = (self.past_channels+len(self.embs_past))+list(exogenous_channel_indices_cat)
         else:
-            self.
-
-
-
-
-
-
+            self.index_fut_cat = []
+        self.freq = freq
+
+        base_freq_token = get_frequency_token(self.freq) # e.g., shape [n_token] or scalar
+        # ensure it's a tensor of integer type
+        if not torch.is_tensor(base_freq_token):
+            base_freq_token = torch.tensor(base_freq_token)
+        base_freq_token = base_freq_token.long()
+        self.register_buffer("token", base_freq_token, persistent=True)
+
 
         self.model = get_model(
             model_path=model_path,
-            context_length=past_steps,
-            prediction_length=future_steps,
-            freq_prefix_tuning=freq_prefix_tuning,
-            freq=freq,
+            context_length=self.past_steps,
+            prediction_length=self.future_steps,
             prefer_l1_loss=prefer_l1_loss,
             prefer_longer_context=prefer_longer_context,
-            num_input_channels=
+            num_input_channels=self.past_channels+len(self.embs_past), #giusto
             decoder_mode=decoder_mode,
             prediction_channel_indices=list(prediction_channel_indices),
-            exogenous_channel_indices=
+            exogenous_channel_indices=self.index_fut + self.index_fut_cat,
             fcm_context_length=fcm_context_length,
             fcm_use_mixer=fcm_use_mixer,
            fcm_mix_layers=fcm_mix_layers,
+            freq=freq,
+            freq_prefix_tuning=freq_prefix_tuning,
             fcm_prepend_past=fcm_prepend_past,
-            #loss='mse',
             enable_forecast_channel_mixing=enable_forecast_channel_mixing,
+
         )
-        self.
-
-
-        prediction_channel_indices = list(range(out_channels * 3))
-        exogenous_channel_indices = [prediction_channel_indices[-1] + i for i in range(1, len(exogenous_channel_indices)+1)]
-        num_input_channels = len(prediction_channel_indices) + len(exogenous_channel_indices)
-        return prediction_channel_indices, exogenous_channel_indices, num_input_channels
+        hidden_size = self.model.config.hidden_size
+        self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
+        self._freeze_backbone()
 
-    def
+    def _freeze_backbone(self):
         """
         Freeze the backbone of the model.
         This is useful when you want to fine-tune only the head of the model.
         """
-
-
-            count_parameters(self.model),
-        )
+        beauty_string(f"Number of params before freezing backbone:{count_parameters(self.model)}",'info',self.verbose)
+
         # Freeze the backbone of the model
         for param in self.model.backbone.parameters():
             param.requires_grad = False
         # Count params
-
-
-            count_parameters(self.model),
-        )
+        beauty_string(f"Number of params after freezing the backbone: {count_parameters(self.model)}",'info',self.verbose)
+
 
-    def
-
-
+    def _scaler_past(self, input):
+        for i, e in enumerate(self.embs_past):
+            input[:,:,i] = input[:, :, i] / (e-1)
+        return input
+    def _scaler_fut(self, input):
+        for i, e in enumerate(self.embs_fut):
             input[:,:,i] = input[:, :, i] / (e-1)
         return input
-
-    def __build_tupla_indexes(self, size, target_idx, current_idx):
-        permute = list(range(size))
-        history = dict()
-        for j, i in enumerate(target_idx):
-            c = history.get(current_idx[j], current_idx[j])
-            permute[i], permute[c] = current_idx[j], i
-            history[i] = current_idx[j]
-
-
-    def __permute_indexes(self, values, target_idx, current_idx):
-        if current_idx is None or target_idx is None:
-            raise ValueError("Indexes cannot be None")
-        if sorted(current_idx) != sorted(target_idx):
-            return values[..., self.__build_tupla_indexes(values.shape[-1], target_idx, current_idx)]
-        return values
-
-    def __extend_with_quantile_variables(self, x, original_indexes):
-        covariate_indexes = [i for i in range(x.shape[-1]) if i not in original_indexes]
-        covariate_tensors = x[..., covariate_indexes]
-
-        new_tensors = [x[..., target_index] for target_index in original_indexes for _ in range(3)]
 
-
-        return
-
+    def can_be_compiled(self):
+        return True
+
     def forward(self, batch):
-        x_enc = batch['x_num_past']
+        x_enc = batch['x_num_past'].to(self.device)
         original_indexes = batch['idx_target'][0].tolist()
-        original_indexes_future = batch['idx_target_future'][0].tolist()
 
 
-        if self.extend_variables:
-            x_enc, original_indexes = self.__extend_with_quantile_variables(x_enc, original_indexes)
-
         if 'x_cat_past' in batch.keys():
             x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
-            x_mark_enc = self.
+            x_mark_enc = self._scaler_past(x_mark_enc)
             past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
         else:
             past_values = x_enc
 
-
+        future_values = torch.zeros_like(past_values).to(self.device)
+        future_values = future_values[:,:self.future_steps,:]
+
         if 'x_num_future' in batch.keys():
-
-            if self.extend_variables:
-                x_dec, original_indexes_future = self.__extend_with_quantile_variables(x_dec, original_indexes_future)
+            future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
         if 'x_cat_future' in batch.keys():
             x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
-            x_mark_dec = self.
-            future_values =
-
-            future_values = x_dec
-
-        if self.remove_last:
-            idx_target = batch['idx_target'][0]
-            x_start = x_enc[:,-1,idx_target].unsqueeze(1)
-            x_enc[:,:,idx_target]-=x_start
-
+            x_mark_dec = self._scaler_fut(x_mark_dec)
+            future_values[:,:,self.index_cat_fut] = x_mark_dec
+
 
-
+        #investigating!! problem with dynamo!
+        #freq_token = get_frequency_token(self.freq).repeat(past_values.shape[0])
 
-
-
+        batch_size = past_values.shape[0]
+        freq_token = self.token.repeat(batch_size).long().to(self.device)
 
-        freq_token = get_frequency_token(self.freq).repeat(x_enc.shape[0])
 
         res = self.model(
             past_values= past_values,
-            future_values= future_values,
+            future_values= future_values,# future_values if future_values.shape[0]>0 else None,
             past_observed_mask = None,
             future_observed_mask = None,
             output_hidden_states = False,
             return_dict = False,
-            freq_token= freq_token,
+            freq_token= freq_token,#[0:past_values.shape[0]], ##investigating
             static_categorical_values = None
         )
-
-
+
+
         BS = res.shape[0]
         return res.reshape(BS,self.future_steps,-1,self.mul)
```
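
A notable detail in the TTM rewrite: the frequency token is computed once in `__init__` and stored with `register_buffer`, instead of calling `get_frequency_token` inside `forward` (the commented-out lines record a dynamo problem with the old approach). Buffers follow `.to(device)` and are saved in the state dict, so the per-batch token becomes a plain `repeat`. A reduced sketch of the pattern with a placeholder token id:

```python
import torch
import torch.nn as nn

class FreqTokenModule(nn.Module):
    """Sketch: cache a frequency token as a buffer at init time."""
    def __init__(self, freq_token: int = 3):  # 3 is a placeholder id
        super().__init__()
        # persistent buffer: lives in the state_dict and follows .to(device)
        self.register_buffer("token", torch.tensor(freq_token).long(), persistent=True)

    def forward(self, past_values: torch.Tensor) -> torch.Tensor:
        batch_size = past_values.shape[0]
        return self.token.repeat(batch_size)  # one token id per batch element

m = FreqTokenModule()
print(m(torch.zeros(4, 16, 2)))  # tensor([3, 3, 3, 3])
```
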
dsipts/models/TimeXER.py
CHANGED
dsipts/models/base.py
CHANGED
```diff
@@ -111,8 +111,11 @@ class Base(pl.LightningModule):
         self.train_loss_epoch = -100.0
         self.verbose = verbose
         self.name = self.__class__.__name__
-        self.train_epoch_metrics
-        self.validation_epoch_metrics
+        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("validation_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("train_epoch_count", torch.tensor(0))
+        self.register_buffer("validation_epoch_count", torch.tensor(0))
+
 
         self.use_quantiles = True if len(quantiles)>0 else False
         self.quantiles = quantiles
@@ -295,7 +298,8 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         loss = self.compute_loss(batch,y_hat)
 
-        self.train_epoch_metrics
+        self.train_epoch_metrics+=loss.detach()
+        self.train_epoch_count +=1
         return loss
 
 
@@ -311,27 +315,20 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         score = 0
         if batch_idx==0:
-
-            idx = 1
-        else:
-            idx = 0
-        #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
-
+
             if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:
-
-
-
-
-
-
-
-
-
-
-
-
-            return self.compute_loss(batch,y_hat)+score
-
+                self._val_outputs.append({
+                    "y": batch['y'].detach().cpu(),
+                    "y_hat": y_hat.detach().cpu()
+                })
+        self.validation_epoch_metrics = (self.compute_loss(batch,y_hat)+score).detach()
+        self.validation_epoch_count+=1
+
+        return None #self.compute_loss(batch,y_hat)+score
+
+    def on_validation_start(self):
+        # reset buffer each epoch
+        self._val_outputs = []
 
     def validation_epoch_end(self, outs):
         """
@@ -339,14 +336,30 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
-        if len(
-
-
-
-
-
-
+        if len(self._val_outputs)>0:
+            ys = torch.cat([o["y"] for o in self._val_outputs])
+            y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
+            if self.use_quantiles:
+                idx = 1
+            else:
+                idx = 0
+            for i in range(ys.shape[2]):
+                real = ys[0,:,i].cpu().detach().numpy()
+                pred = y_hats[0,:,i,idx].cpu().detach().numpy()
+                fig, ax = plt.subplots(figsize=(7,5))
+                ax.plot(real,'o-',label='real')
+                ax.plot(pred,'o-',label='pred')
+                ax.legend()
+                ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                self.logger.experiment.track(Image(fig), name='cm_training_end')
+                #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+                plt.close(fig)
+        avg = self.validation_epoch_metrics/self.validation_epoch_count
+
+        self.validation_epoch_metrics.zero_()
+        self.validation_epoch_count.zero_()
+        self.log("val_loss", avg,sync_dist=True)
+        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
 
     def training_epoch_end(self, outs):
         """
@@ -355,12 +368,11 @@ class Base(pl.LightningModule):
         :meta private:
         """
 
-        loss =
-        self.log("train_loss", loss
+        loss = self.train_epoch_metrics/self.global_step
+        self.log("train_loss", loss,sync_dist=True)
         self.count_epoch+=1
 
-        self.train_loss_epoch = loss
-
+        self.train_loss_epoch = loss
     def compute_loss(self,batch,y_hat):
         """
         custom loss calculation
```
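
Both base classes replace the old Python-side accumulators with registered tensor buffers plus step counters: the running sums stay on the module's device, `loss.detach()` keeps the autograd graph from being retained, and `zero_()` resets in place at epoch end. A reduced sketch of the accumulate/average/reset cycle outside Lightning:

```python
import torch
import torch.nn as nn

class MetricAccumulator(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
        self.register_buffer("train_epoch_count", torch.tensor(0))

    def on_step(self, loss: torch.Tensor) -> None:
        self.train_epoch_metrics += loss.detach()  # detach: no graph retention
        self.train_epoch_count += 1

    def on_epoch_end(self) -> float:
        avg = (self.train_epoch_metrics / self.train_epoch_count).item()
        self.train_epoch_metrics.zero_()  # in-place reset keeps device/dtype
        self.train_epoch_count.zero_()
        return avg

acc = MetricAccumulator()
for loss in (torch.tensor(1.0), torch.tensor(3.0)):
    acc.on_step(loss)
print(acc.on_epoch_end())  # 2.0
```
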
dsipts/models/base_v2.py
CHANGED
```diff
@@ -15,7 +15,6 @@ from typing import List, Union
 from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
 
-
 def standardize_momentum(x,order):
     mean = torch.mean(x,1).unsqueeze(1).repeat(1,x.shape[1],1)
     num = torch.pow(x-mean,order).mean(axis=1)
@@ -113,8 +112,13 @@ class Base(pl.LightningModule):
         self.train_loss_epoch = -100.0
         self.verbose = verbose
         self.name = self.__class__.__name__
-        self.train_epoch_metrics =
-        self.validation_epoch_metrics =
+        #self.train_epoch_metrics = 0
+        #self.validation_epoch_metrics = 0
+
+        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("validation_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("train_epoch_count", torch.tensor(0))
+        self.register_buffer("validation_epoch_count", torch.tensor(0))
 
         self.use_quantiles = True if len(quantiles)>0 else False
         self.quantiles = quantiles
@@ -299,10 +303,11 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         loss = self.compute_loss(batch,y_hat)
 
-        self.train_epoch_metrics
+        self.train_epoch_metrics+=loss.detach()
+        self.train_epoch_count +=1
         return loss
 
-
+
     def validation_step(self, batch, batch_idx):
         """
         pythotrch lightening stuff
@@ -315,42 +320,54 @@ class Base(pl.LightningModule):
         else:
             y_hat = self(batch)
         score = 0
-
-        if self.use_quantiles:
-            idx = 1
-        else:
-            idx = 0
-        #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
+        #log_this_batch = (batch_idx == 0) and (self.count_epoch % int(max(self.trainer.max_epochs / 100,1)) == 1)
 
-
-
-
-
-
-
-
-
-
-        ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
-        self.logger.experiment.track(Image(fig), name='cm_training_end')
-        #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
-        self.validation_epoch_metrics.append(self.compute_loss(batch,y_hat)+score)
-        return
+        #if log_this_batch:
+        #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
+        self._val_outputs=[{
+            "y": batch['y'].detach().cpu(),
+            "y_hat": y_hat.detach().cpu()
+        }]
+        self.validation_epoch_metrics+= (self.compute_loss(batch,y_hat)+score).detach()
+        self.validation_epoch_count+=1
+        return None
 
+    def on_validation_start(self):
+        # reset buffer each epoch
+        self._val_outputs = []
+
 
     def on_validation_epoch_end(self):
         """
         pythotrch lightening stuff
 
         :meta private:
-    """
+        """
 
-        if len(self.
-
-
-
-
-
+        if len(self._val_outputs)>0:
+            ys = torch.cat([o["y"] for o in self._val_outputs])
+            y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
+            if self.use_quantiles:
+                idx = 1
+            else:
+                idx = 0
+            for i in range(ys.shape[2]):
+                real = ys[0,:,i].cpu().detach().numpy()
+                pred = y_hats[0,:,i,idx].cpu().detach().numpy()
+                fig, ax = plt.subplots(figsize=(7,5))
+                ax.plot(real,'o-',label='real')
+                ax.plot(pred,'o-',label='pred')
+                ax.legend()
+                ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                self.logger.experiment.track(Image(fig), name='cm_training_end')
+                #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+                plt.close(fig)
+
+
+        avg = self.validation_epoch_metrics/self.validation_epoch_count
+
+        self.validation_epoch_metrics.zero_()
+        self.validation_epoch_count.zero_()
         self.log("val_loss", avg,sync_dist=True)
         beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
 
@@ -361,14 +378,12 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
-
-
-        beauty_string(f'THIS IS A BUG, It should be polulated','info',self.verbose)
-        else:
-            avg = np.stack(self.train_epoch_metrics).mean()
+
+        avg = self.train_epoch_metrics/self.train_epoch_count
         self.log("train_loss", avg,sync_dist=True)
         self.count_epoch+=1
-        self.train_epoch_metrics
+        self.train_epoch_metrics.zero_()
+        self.train_epoch_count.zero_()
        self.train_loss_epoch = avg
 
    def compute_loss(self,batch,y_hat):
```
dsipts/models/duet/layers.py
CHANGED
```diff
@@ -219,7 +219,7 @@ class SparseDispatcher(object):
         # expand according to batch index so we can just split by _part_sizes
         inp_exp = inp[self._batch_index].squeeze(1)
         return torch.split(inp_exp, self._part_sizes, dim=0)
-
+
     def combine(self, expert_out, multiply_by_gates=True):
         """Sum together the expert output, weighted by the gates.
         The slice corresponding to a particular batch element `b` is computed
@@ -234,7 +234,9 @@ class SparseDispatcher(object):
         a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
         """
         # apply exp to expert outputs, so we are not longer in log space
+
         stitched = torch.cat(expert_out, 0)
+
         if multiply_by_gates:
             # stitched = stitched.mul(self._nonzero_gates)
             stitched = torch.einsum("i...,ij->i...", stitched, self._nonzero_gates)
@@ -430,9 +432,11 @@ class Linear_extractor_cluster(nn.Module):
         expert_inputs = dispatcher.dispatch(x_norm)
 
         gates = dispatcher.expert_to_gates()
+
         expert_outputs = [
             self.experts[i](expert_inputs[i]) for i in range(self.num_experts)
         ]
+        #y = dispatcher.combine([e for e in expert_outputs if len(e)>0])
+        #with torch._dynamo.disable():
         y = dispatcher.combine(expert_outputs)
-
         return y, loss
```
{dsipts-1.1.10.dist-info → dsipts-1.1.12.dist-info}/RECORD
CHANGED

```diff
@@ -3,33 +3,33 @@ dsipts/data_management/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
 dsipts/data_management/monash.py,sha256=aZxq9FbIH6IsU8Lwou1hAokXjgOAK-wdl2VAeFg2k4M,13075
 dsipts/data_management/public_datasets.py,sha256=yXFzOZZ-X0ZG1DoqVU-zFmEGVMc2033YDQhRgYxY8ws,6793
 dsipts/data_structure/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dsipts/data_structure/data_structure.py,sha256=
+dsipts/data_structure/data_structure.py,sha256=KVkjTVjc7NznJIou4LYGzMbzE7ye-K3ll65GEgn2qKg,60814
 dsipts/data_structure/modifiers.py,sha256=qlry9dfw8pEE0GrvgwROZJkJ6oPpUnjEHPIG5qIetss,7948
 dsipts/data_structure/utils.py,sha256=QwfKPZgSy6DIw5n6ztOdPJIAnzo4EnlMTgRbpiWnyko,6593
-dsipts/models/Autoformer.py,sha256=
-dsipts/models/CrossFormer.py,sha256=
-dsipts/models/D3VAE.py,sha256=
-dsipts/models/Diffusion.py,sha256=
-dsipts/models/DilatedConv.py,sha256=
-dsipts/models/DilatedConvED.py,sha256=
-dsipts/models/Duet.py,sha256=
-dsipts/models/ITransformer.py,sha256=
-dsipts/models/Informer.py,sha256=
-dsipts/models/LinearTS.py,sha256=
-dsipts/models/PatchTST.py,sha256=
+dsipts/models/Autoformer.py,sha256=nUQvPC_qtajLT1AHdNJmF_P3ZL01j3spkZ4ubxdGF3g,8497
+dsipts/models/CrossFormer.py,sha256=ClW6H_hrtLJH0iqTC7q_ya_Bwc_Xu-0lpAN5w2DSUYk,6526
+dsipts/models/D3VAE.py,sha256=d1aY6kGjBSxZncN-KPWpdUGunu182ng2QFInGFrKYQM,6903
+dsipts/models/Diffusion.py,sha256=owst4IxA3hkEEIrn5K-zwAYWUzEhouiRPwM4nTLcyoE,40786
+dsipts/models/DilatedConv.py,sha256=TMDzd_cNgCZa6YusVVVGbTGGH3YlMz0IZZ9ZxRrJ3i4,14334
+dsipts/models/DilatedConvED.py,sha256=KwG83yHqoEx_Vmea69zTPsSP1-0GdOUrtXwvhNDuWj8,14048
+dsipts/models/Duet.py,sha256=m67PStuYE6vkFUFUofBrrLryx1ZUZropyVGcu_ygOx8,7681
+dsipts/models/ITransformer.py,sha256=2WXqqEvnWH2DqRQyXfGm4Eg4_q32GFy2XnNeoTl-KmY,7310
+dsipts/models/Informer.py,sha256=gxCdU2KkNhadyMujBA5A0eP6SPN4Q0IkEIogLYwvz5k,6970
+dsipts/models/LinearTS.py,sha256=vXaGpbbkfdpzpTEWZ1hs6QI6j3vDvevD3SyKQXo6Sdg,9151
+dsipts/models/PatchTST.py,sha256=1O09cPMg8USdkt5q6szTiz5dIY45kizsf6gt6vLKnQo,9119
 dsipts/models/Persistent.py,sha256=URwyaBb0M7zbPXSGMImtHlwC9XCy-OquFCwfWvn3P70,1249
-dsipts/models/RNN.py,sha256=
-dsipts/models/Samformer.py,sha256=
-dsipts/models/Simple.py,sha256=
-dsipts/models/TFT.py,sha256=
-dsipts/models/TIDE.py,sha256=
-dsipts/models/TTM.py,sha256=
-dsipts/models/TimeXER.py,sha256=
+dsipts/models/RNN.py,sha256=RnsRDAQ2z5-XNaJVZd6Q7z23WvPR2uLVdi7BNQyF7QE,9685
+dsipts/models/Samformer.py,sha256=Kt7B9ID3INtFDAVKIM1LTly5-UfKCaVZ9uxAJmYv6B4,5606
+dsipts/models/Simple.py,sha256=8wRSO-gh_Z6Sl8fYMV-RIXIL0RrO5u5dDtsaq-OsKg0,3960
+dsipts/models/TFT.py,sha256=JiI90ikfP8aaR_rtczu8CyGMNLTgml13aYQifgIC_yo,13888
+dsipts/models/TIDE.py,sha256=S1KlKqFOR3jJ9DDiTqeaKvya9hYBsNHBVqwJsYX3FLU,13094
+dsipts/models/TTM.py,sha256=lOOo5dR5nOmf37cND6C8ft8TVl0kzNeraIuABw7eI5g,5897
+dsipts/models/TimeXER.py,sha256=EkmlHfT2RegY6Ce6q8EUEV1a_WZ6SkYibnOZXqsyd_8,7111
 dsipts/models/VQVAEA.py,sha256=sNJi8UZh-10qEIKcZK3SzhlOFUUjvqjoglzeZBFaeZM,13789
 dsipts/models/VVA.py,sha256=BnPkJ0Nzue0oShSHZVRNlf5RvT0Iwtf9bx19vLB9Nn0,11939
 dsipts/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dsipts/models/base.py,sha256=
-dsipts/models/base_v2.py,sha256=
+dsipts/models/base.py,sha256=Gqsycy8ZXGaIVx9vvmYRpBCqdUxGE4tvC5ltgxlpEYY,19640
+dsipts/models/base_v2.py,sha256=03cueZExRhkJyBVIHuUPB8sjsCd5Go1HJAR81CADg-c,19896
 dsipts/models/utils.py,sha256=kjTwyktNCFMpPUy6zoleBCSKlvMvK_Jkgyh2T1OXg3E,24497
 dsipts/models/autoformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dsipts/models/autoformer/layers.py,sha256=xHt8V1lKdD1cIvgxXdDbI_EqOz4zgOQ6LP8l7M1pAxM,13276
@@ -47,7 +47,7 @@ dsipts/models/d3vae/neural_operations.py,sha256=C70kUtQ0ox9MeXBdu4rPDqt022_hVtcN
 dsipts/models/d3vae/resnet.py,sha256=3bnlrEBM2DGiAJV8TeSv2tm27Gm-_P6hee41t8QQFL8,5520
 dsipts/models/d3vae/utils.py,sha256=fmUsE_67uwizjeR1_pDdsndyQddbqt27Lv31XBEn-gw,23798
 dsipts/models/duet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dsipts/models/duet/layers.py,sha256=
+dsipts/models/duet/layers.py,sha256=TTrhlfSwIXE_7gO9rsdKJD9Bdy3B_JJPCo8vYZJ8Fvg,18258
 dsipts/models/duet/masked.py,sha256=lkdAB5kwAgV7QfBSVP_QeDr_mB09Rz4302p-KwZpUV4,7111
 dsipts/models/informer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dsipts/models/informer/attn.py,sha256=ghrQGfAqt-Z_7qU5D_aixobmwk6pBKMLAdaNfg-QZbo,6839
@@ -76,7 +76,7 @@ dsipts/models/vva/minigpt.py,sha256=bg0JddqSD322uxSGexen3nPXL_hGTsk3vNLR62d7-w8,
 dsipts/models/vva/vqvae.py,sha256=RzCQ_M9xBprp7_x20dSV3EQqlO0FjPUGWV-qdyKrQsM,19680
 dsipts/models/xlstm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dsipts/models/xlstm/xLSTM.py,sha256=ZKZZmffmIq1Vb71CR4GSyM8viqVx-u0FChxhcNgHub8,10081
-dsipts-1.1.
-dsipts-1.1.
-dsipts-1.1.
-dsipts-1.1.
+dsipts-1.1.12.dist-info/METADATA,sha256=nxE2kAg9RvG5Py27sMNbQ-mUIu9mtZrDo2WocLpJdQ4,24795
+dsipts-1.1.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dsipts-1.1.12.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
+dsipts-1.1.12.dist-info/RECORD,,
```
{dsipts-1.1.10.dist-info → dsipts-1.1.12.dist-info}/WHEEL
File without changes

{dsipts-1.1.10.dist-info → dsipts-1.1.12.dist-info}/top_level.txt
File without changes