dsipts 1.1.12.tar.gz → 1.1.13.tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of dsipts might be problematic.
- {dsipts-1.1.12 → dsipts-1.1.13}/PKG-INFO +1 -1
- {dsipts-1.1.12 → dsipts-1.1.13}/pyproject.toml +1 -1
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/data_structure/data_structure.py +27 -10
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/data_structure/utils.py +4 -2
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/TTM.py +27 -8
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts.egg-info/PKG-INFO +1 -1
- {dsipts-1.1.12 → dsipts-1.1.13}/README.md +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/setup.cfg +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/data_management/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/data_management/monash.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/data_management/public_datasets.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/data_structure/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/data_structure/modifiers.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/Autoformer.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/CrossFormer.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/D3VAE.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/Diffusion.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/DilatedConv.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/DilatedConvED.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/Duet.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/ITransformer.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/Informer.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/LinearTS.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/PatchTST.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/Persistent.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/RNN.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/Samformer.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/Simple.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/TFT.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/TIDE.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/TimeXER.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/VQVAEA.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/VVA.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/autoformer/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/autoformer/layers.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/base.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/base_v2.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/crossformer/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/crossformer/attn.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/crossformer/cross_decoder.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/crossformer/cross_embed.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/crossformer/cross_encoder.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/d3vae/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/d3vae/diffusion_process.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/d3vae/embedding.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/d3vae/encoder.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/d3vae/model.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/d3vae/neural_operations.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/d3vae/resnet.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/d3vae/utils.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/duet/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/duet/layers.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/duet/masked.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/informer/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/informer/attn.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/informer/decoder.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/informer/embed.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/informer/encoder.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/itransformer/Embed.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/itransformer/SelfAttention_Family.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/itransformer/Transformer_EncDec.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/itransformer/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/patchtst/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/patchtst/layers.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/samformer/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/samformer/utils.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/tft/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/tft/sub_nn.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/timexer/Layers.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/timexer/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/ttm/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/ttm/configuration_tinytimemixer.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/ttm/consts.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/ttm/modeling_tinytimemixer.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/ttm/utils.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/utils.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/vva/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/vva/minigpt.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/vva/vqvae.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/xlstm/__init__.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts/models/xlstm/xLSTM.py +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts.egg-info/SOURCES.txt +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts.egg-info/dependency_links.txt +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts.egg-info/requires.txt +0 -0
- {dsipts-1.1.12 → dsipts-1.1.13}/src/dsipts.egg-info/top_level.txt +0 -0
src/dsipts/data_structure/data_structure.py

```diff
@@ -6,7 +6,7 @@ from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
 from sklearn.preprocessing import *
 from torch.utils.data import DataLoader
 from .utils import extend_time_df,MetricsCallback, MyDataset, ActionEnum,beauty_string
-
+from torch.utils.data.sampler import WeightedRandomSampler
 try:
 
     #new version of lightning
```
```diff
@@ -249,7 +249,8 @@ class TimeSeries():
                     check_past:bool=True,
                     group:Union[None,str]=None,
                     check_holes_and_duplicates:bool=True,
-                    silly_model:bool=False
+                    silly_model:bool=False,
+                    sampler_weights:Union[None,str]=None)->None:
         """ This is a crucial point in the data structure. We expect here to have a dataset with time as timestamp.
         There are some checks:
         1- the duplicates will tbe removed taking the first instance
```
```diff
@@ -270,6 +271,7 @@ class TimeSeries():
             group (str or None, optional): if not None the time serie dataset is considered composed by omogeneus timeseries coming from different realization (for example point of sales, cities, locations) and the relative series are not splitted during the sample generation. Defaults to None
             check_holes_and_duplicates (bool, optional): if False duplicates or holes will not checked, the dataloader can not correctly work, disable at your own risk. Defaults True
             silly_model (bool, optional): if True, target variables will be added to the pool of the future variables. This can be useful to see if information passes throught the decoder part of your model (if any)
+            sampler_weights group (str or None, optional): if it is a column name it will be used as weight for the sampler. Careful that the weight of the sample is the weight value of the fist target value (index)
         """
 
 
```
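In practice the new argument is just a column name in the input frame. A hypothetical call sketch follows; apart from `sampler_weights`, the class, method, and argument names are assumptions about dsipts' API, since only the signature fragment is visible in this hunk:

```python
# Hypothetical usage sketch: only the sampler_weights argument comes from the
# diff above; the class/method/argument names are assumptions about dsipts' API.
import pandas as pd
from dsipts import TimeSeries

df = pd.DataFrame({'time': pd.date_range('2024-01-01', periods=100, freq='h'),
                   'y': range(100),
                   'weight': [1.0] * 50 + [5.0] * 50})   # oversample recent rows

ts = TimeSeries('example')
ts.load_signal(df, target_variables=['y'],
               sampler_weights='weight')   # column used to weight training windows
```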
```diff
@@ -322,7 +324,7 @@ class TimeSeries():
         if group is not None:
             if group not in cat_past_var:
                 beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
-                self.
+                self.cat_past_var.append(group)
             if group not in cat_fut_var:
                 beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
                 self.cat_fut_var.append(group)
```
```diff
@@ -350,7 +352,7 @@ class TimeSeries():
         if silly_model:
             beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
             self.future_variables+=self.target_variables
-
+        self.sampler_weights = sampler_weights
     def plot(self):
         """
         Easy way to control the loaded data
```
```diff
@@ -409,6 +411,7 @@ class TimeSeries():
         y_samples = []
         t_samples = []
         g_samples = []
+        sampler_weights_samples = []
 
         if starting_point is not None:
             kk = list(starting_point.keys())[0]
```
```diff
@@ -475,7 +478,8 @@ class TimeSeries():
             if len(self.cat_fut_var)>0:
                 x_fut_cat = tmp[self.cat_fut_var].values
             y_target = tmp[self.target_variables].values
-
+            if self.sampler_weights is not None:
+                sampler_weights = tmp[self.sampler_weights].values.flatten()
 
             if starting_point is not None:
                 check = tmp[list(starting_point.keys())[0]].values == starting_point[list(starting_point.keys())[0]]
```
```diff
@@ -512,6 +516,8 @@ class TimeSeries():
                     x_cat_future_samples.append(x_fut_cat[i-shift+skip_stacked:i+future_steps-shift+skip_stacked])
 
                 y_samples.append(y_target[i+skip_stacked:i+future_steps+skip_stacked])
+                if self.sampler_weights is not None:
+                    sampler_weights_samples.append(sampler_weights[i+skip_stacked])
                 t_samples.append(t[i+skip_stacked:i+future_steps+skip_stacked])
                 g_samples.append(groups[i])
```
```diff
@@ -524,6 +530,8 @@ class TimeSeries():
             beauty_string('WARNING x_num_future_samples is empty and it should not','info',True)
 
         y_samples = np.stack(y_samples)
+        if self.sampler_weights is not None:
+            sampler_weights_samples = np.stack(sampler_weights_samples)
         t_samples = np.stack(t_samples)
         g_samples = np.stack(g_samples)
```
```diff
@@ -537,7 +545,6 @@ class TimeSeries():
         else:
             mod = 1.0
         dd = {'y':y_samples.astype(np.float32),
-
              'x_num_past':(x_num_past_samples*mod).astype(np.float32)}
         if len(self.cat_past_var)>0:
             dd['x_cat_past'] = x_cat_past_samples
```
```diff
@@ -545,7 +552,10 @@ class TimeSeries():
             dd['x_cat_future'] = x_cat_future_samples
         if len(self.future_variables)>0:
             dd['x_num_future'] = x_num_future_samples.astype(np.float32)
-
+        if self.sampler_weights is not None:
+            dd['sampler_weights'] = sampler_weights_samples.astype(np.float32)
+        else:
+            dd['sampler_weights'] = np.ones(len(y_samples)).astype(np.float32)
         return MyDataset(dd,t_samples,g_samples,idx_target,idx_target_future)
 
 
```
```diff
@@ -753,8 +763,14 @@ class TimeSeries():
         else:
             self.modifier = None
 
+        if self.sampler_weights is not None:
+            beauty_string(f'USING SAMPLER IN TRAIN {min(train.sampler_weights)}-{max(train.sampler_weights)}','section',self.verbose)
+
+            sampler = WeightedRandomSampler(train.sampler_weights, num_samples= len(train))
+            train_dl = DataLoader(train, batch_size = batch_size , shuffle=False,sampler=sampler,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
 
-
+        else:
+            train_dl = DataLoader(train, batch_size = batch_size , shuffle=True,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
         valid_dl = DataLoader(validation, batch_size = batch_size , shuffle=False,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
 
         checkpoint_callback = ModelCheckpoint(dirpath=dirpath,
```
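The effect of this hunk: when a weight column is configured, the training loader stops shuffling and instead draws windows in proportion to their weights. A minimal, self-contained sketch of the mechanics (the toy dataset and weight values are invented, not from dsipts):

```python
# Minimal sketch of WeightedRandomSampler, mirroring the hunk above.
# The toy dataset and the weight values are invented for illustration.
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

data = TensorDataset(torch.arange(4, dtype=torch.float32).unsqueeze(1))
weights = [0.1, 0.1, 0.1, 0.7]   # item 3 should be drawn roughly 70% of the time

# shuffle must be False (or unset) whenever an explicit sampler is passed
sampler = WeightedRandomSampler(weights, num_samples=len(data))
loader = DataLoader(data, batch_size=2, shuffle=False, sampler=sampler)

for (batch,) in loader:
    print(batch.flatten().tolist())   # items appear in proportion to their weights
```

Note that the sampler's default `replacement=True` means heavily weighted windows can appear more than once per epoch, which is what makes the reweighting effective.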
```diff
@@ -1026,7 +1042,7 @@ class TimeSeries():
 
         if self.group is not None:
             time[self.group] = groups
-            time = time.melt(id_vars=[
+            time = time.melt(id_vars=[self.group])
         else:
             time = time.melt()
         time.rename(columns={'value':'time','variable':'lag'},inplace=True)
```
```diff
@@ -1048,7 +1064,8 @@ class TimeSeries():
 
         if self.group is not None:
             time[self.group] = groups
-
+
+            time = time.melt(id_vars=[self.group])
         else:
             time = time.melt()
         time.rename(columns={'value':'time','variable':'lag'},inplace=True)
```
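Both hunks complete the `melt` call so the group column is preserved as an identifier while the lag columns are unpivoted. A toy pandas illustration (the frame and column names are invented):

```python
# Toy illustration of melt(id_vars=[group]) as used above; data is invented.
import pandas as pd

time = pd.DataFrame({'store': ['a', 'b'], 'lag_1': [1, 2], 'lag_2': [3, 4]})
long = time.melt(id_vars=['store'])          # 'store' kept, lag columns unpivoted
long.rename(columns={'value': 'time', 'variable': 'lag'}, inplace=True)
print(long)   # columns: store, lag, time
```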
src/dsipts/data_structure/utils.py

```diff
@@ -142,12 +142,13 @@ class MyDataset(Dataset):
         Returns:
             torch.utils.data.Dataset: a torch Dataset to be used in a Dataloader
         """
+
         self.data = data
         self.t = t
         self.groups = groups
         self.idx_target = np.array(idx_target) if idx_target is not None else None
         self.idx_target_future = np.array(idx_target_future) if idx_target_future is not None else None
-
+        self.sampler_weights = data['sampler_weights']
 
 
     def __len__(self):
```
```diff
@@ -157,7 +158,8 @@ class MyDataset(Dataset):
     def __getitem__(self, idxs):
         sample = {}
         for k in self.data:
-
+            if k!='sampler_weights':
+                sample[k] = self.data[k][idxs]
         if self.idx_target is not None:
             sample['idx_target'] = self.idx_target
         if self.idx_target_future is not None:
```
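The convention here: the weights travel inside the data dict, are exposed as a dataset attribute for building the sampler, and are excluded from every sample so batches never carry them. A simplified stand-in showing the same pattern (not the real MyDataset):

```python
# Simplified stand-in for the pattern above -- not the real MyDataset.
import numpy as np
from torch.utils.data import Dataset

class WeightedDictDataset(Dataset):
    def __init__(self, data: dict):
        self.data = data
        # kept on the dataset for building the sampler, never emitted in a batch
        self.sampler_weights = data['sampler_weights']

    def __len__(self):
        return len(self.data['y'])

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items() if k != 'sampler_weights'}

ds = WeightedDictDataset({'y': np.zeros((8, 3), dtype=np.float32),
                          'sampler_weights': np.ones(8, dtype=np.float32)})
print(list(ds[0].keys()))   # ['y'] -- no 'sampler_weights' in the sample
```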
src/dsipts/models/TTM.py

```diff
@@ -38,6 +38,7 @@ class TTM(Base):
                  fcm_mix_layers,
                  fcm_prepend_past,
                  enable_forecast_channel_mixing,
+                 force_return,
                  **kwargs)->None:
 
         super().__init__(**kwargs)
```
```diff
@@ -48,7 +49,9 @@ class TTM(Base):
         self.index_fut = list(exogenous_channel_indices_cont)
 
         if len(exogenous_channel_indices_cat)>0:
-
+
+            self.index_fut_cat = [self.past_channels+c for c in list(exogenous_channel_indices_cat)]
+
         else:
             self.index_fut_cat = []
         self.freq = freq
```
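The restored line offsets each categorical exogenous index by `past_channels`, because those channels sit after the past continuous channels in the concatenated input. A worked toy example (the numbers are assumptions):

```python
# Worked toy example of the index offset above; the numbers are assumptions.
past_channels = 5
exogenous_channel_indices_cat = [0, 2]
index_fut_cat = [past_channels + c for c in exogenous_channel_indices_cat]
print(index_fut_cat)   # [5, 7] -- columns in the concatenated channel layout
```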
```diff
@@ -75,6 +78,7 @@ class TTM(Base):
             fcm_use_mixer=fcm_use_mixer,
             fcm_mix_layers=fcm_mix_layers,
             freq=freq,
+            force_return=force_return,
             freq_prefix_tuning=freq_prefix_tuning,
             fcm_prepend_past=fcm_prepend_past,
             enable_forecast_channel_mixing=enable_forecast_channel_mixing,
```
```diff
@@ -83,7 +87,7 @@ class TTM(Base):
         hidden_size = self.model.config.hidden_size
         self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
         self._freeze_backbone()
-
+        self.zero_pad = (force_return=='zeropad')
     def _freeze_backbone(self):
         """
         Freeze the backbone of the model.
```
```diff
@@ -108,29 +112,44 @@ class TTM(Base):
         return input
 
     def can_be_compiled(self):
-
+
+        return True#not self.zero_pad
 
     def forward(self, batch):
         x_enc = batch['x_num_past'].to(self.device)
+
+
+        if self.zero_pad:
+            B,L,C = batch['x_num_past'].shape
+            x_enc = torch.zeros((B,512,C)).to(self.device)
+            x_enc[:,-L:,:] = batch['x_num_past'].to(self.device)
+        else:
+            x_enc = batch['x_num_past'].to(self.device)
         original_indexes = batch['idx_target'][0].tolist()
 
 
         if 'x_cat_past' in batch.keys():
-
-
+            if self.zero_pad:
+                B,L,C = batch['x_cat_past'].shape
+                x_mark_enc = torch.zeros((B,512,C)).to(self.device)
+                x_mark_enc[:,-L:,:] = batch['x_cat_past'].to(torch.float32).to(self.device)
+            else:
+                x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
+            x_mark_enc = self._scaler_past(x_mark_enc)
             past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
         else:
             past_values = x_enc
+        B,L,C = past_values.shape
+        future_values = torch.zeros((B,self.future_steps,C)).to(self.device)
 
-        future_values = torch.zeros_like(past_values).to(self.device)
-        future_values = future_values[:,:self.future_steps,:]
 
+
         if 'x_num_future' in batch.keys():
             future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
         if 'x_cat_future' in batch.keys():
             x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
             x_mark_dec = self._scaler_fut(x_mark_dec)
-            future_values[:,:,self.
+            future_values[:,:,self.index_fut_cat] = x_mark_dec
 
 
         #investigating!! problem with dynamo!
```
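With `force_return='zeropad'`, the forward pass left-pads both the numeric and the categorical history to a fixed 512-step context (the hard-coded length in the hunk above), writing the observed values into the most recent positions. A standalone sketch of that padding pattern (the helper name and example shapes are illustrative, not from dsipts):

```python
# Sketch of the left zero-padding used in the forward pass above.
# Helper name and example shapes are illustrative, not from dsipts.
import torch

def left_zero_pad(x: torch.Tensor, context: int = 512) -> torch.Tensor:
    B, L, C = x.shape
    padded = torch.zeros((B, context, C), dtype=x.dtype, device=x.device)
    padded[:, -L:, :] = x   # real history occupies the last L positions
    return padded

x = torch.randn(2, 96, 4)        # e.g. a batch of 96 observed steps, 4 channels
print(left_zero_pad(x).shape)    # torch.Size([2, 512, 4])
```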
All remaining files are unchanged.