dsipts: 1.1.12-py3-none-any.whl → 1.1.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


dsipts/data_structure/data_structure.py CHANGED
@@ -6,7 +6,7 @@ from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
 from sklearn.preprocessing import *
 from torch.utils.data import DataLoader
 from .utils import extend_time_df,MetricsCallback, MyDataset, ActionEnum,beauty_string
-
+from torch.utils.data.sampler import WeightedRandomSampler
 try:
 
 #new version of lightning
@@ -249,7 +249,8 @@ class TimeSeries():
                     check_past:bool=True,
                     group:Union[None,str]=None,
                     check_holes_and_duplicates:bool=True,
-                    silly_model:bool=False)->None:
+                    silly_model:bool=False,
+                    sampler_weights:Union[None,str]=None)->None:
         """ This is a crucial point in the data structure. We expect here to have a dataset with time as timestamp.
         There are some checks:
         1- the duplicates will tbe removed taking the first instance
@@ -270,6 +271,7 @@ class TimeSeries():
             group (str or None, optional): if not None the time serie dataset is considered composed by omogeneus timeseries coming from different realization (for example point of sales, cities, locations) and the relative series are not splitted during the sample generation. Defaults to None
             check_holes_and_duplicates (bool, optional): if False duplicates or holes will not checked, the dataloader can not correctly work, disable at your own risk. Defaults True
             silly_model (bool, optional): if True, target variables will be added to the pool of the future variables. This can be useful to see if information passes throught the decoder part of your model (if any)
+            sampler_weights group (str or None, optional): if it is a column name it will be used as weight for the sampler. Careful that the weight of the sample is the weight value of the fist target value (index)
         """
 
 
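Per the new docstring entry, sampler_weights names a column of the input DataFrame, and the weight attached to a training window is the weight value at the window's first target index. A minimal sketch of what such a column might look like (toy data; the column name 'weight' is illustrative, not a dsipts default):

import numpy as np
import pandas as pd

n = 200
df = pd.DataFrame({
    'time': pd.date_range('2024-01-01', periods=n, freq='h'),
    'signal': np.sin(np.arange(n) / 10.0),
    # hypothetical weight column: rows at midnight get 10x the sampling weight
    'weight': np.where(np.arange(n) % 24 == 0, 10.0, 1.0),
})
print(df.head(3))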
@@ -322,7 +324,7 @@ class TimeSeries():
         if group is not None:
             if group not in cat_past_var:
                 beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
-                self.cat_var.append(group)
+                self.cat_past_var.append(group)
             if group not in cat_fut_var:
                 beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
                 self.cat_fut_var.append(group)
@@ -350,7 +352,7 @@ class TimeSeries():
         if silly_model:
             beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
             self.future_variables+=self.target_variables
-
+        self.sampler_weights = sampler_weights
     def plot(self):
         """
         Easy way to control the loaded data
@@ -409,6 +411,7 @@ class TimeSeries():
         y_samples = []
         t_samples = []
         g_samples = []
+        sampler_weights_samples = []
 
         if starting_point is not None:
             kk = list(starting_point.keys())[0]
@@ -475,7 +478,8 @@ class TimeSeries():
             if len(self.cat_fut_var)>0:
                 x_fut_cat = tmp[self.cat_fut_var].values
             y_target = tmp[self.target_variables].values
-
+            if self.sampler_weights is not None:
+                sampler_weights = tmp[self.sampler_weights].values.flatten()
 
             if starting_point is not None:
                 check = tmp[list(starting_point.keys())[0]].values == starting_point[list(starting_point.keys())[0]]
@@ -512,6 +516,8 @@ class TimeSeries():
                 x_cat_future_samples.append(x_fut_cat[i-shift+skip_stacked:i+future_steps-shift+skip_stacked])
 
             y_samples.append(y_target[i+skip_stacked:i+future_steps+skip_stacked])
+            if self.sampler_weights is not None:
+                sampler_weights_samples.append(sampler_weights[i+skip_stacked])
             t_samples.append(t[i+skip_stacked:i+future_steps+skip_stacked])
             g_samples.append(groups[i])
 
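The indexing here deserves a note: a window's targets span y_target[i+skip_stacked : i+future_steps+skip_stacked], yet its sampler weight is the single value at i+skip_stacked, i.e. the weight of the first target step, exactly as the docstring warns. A toy check of that correspondence (illustrative values, with skip_stacked taken as 0):

import numpy as np

future_steps, skip_stacked, i = 3, 0, 3
weights = np.array([1.0, 1.0, 1.0, 10.0, 1.0, 1.0, 1.0, 1.0])

target_idx = np.arange(i + skip_stacked, i + future_steps + skip_stacked)
window_weight = weights[i + skip_stacked]
print(target_idx, window_weight)  # [3 4 5] 10.0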
@@ -524,6 +530,8 @@ class TimeSeries():
             beauty_string('WARNING x_num_future_samples is empty and it should not','info',True)
 
         y_samples = np.stack(y_samples)
+        if self.sampler_weights is not None:
+            sampler_weights_samples = np.stack(sampler_weights_samples)
         t_samples = np.stack(t_samples)
         g_samples = np.stack(g_samples)
 
@@ -537,7 +545,6 @@ class TimeSeries():
         else:
             mod = 1.0
         dd = {'y':y_samples.astype(np.float32),
-
               'x_num_past':(x_num_past_samples*mod).astype(np.float32)}
         if len(self.cat_past_var)>0:
             dd['x_cat_past'] = x_cat_past_samples
@@ -545,7 +552,10 @@ class TimeSeries():
             dd['x_cat_future'] = x_cat_future_samples
         if len(self.future_variables)>0:
             dd['x_num_future'] = x_num_future_samples.astype(np.float32)
-
+        if self.sampler_weights is not None:
+            dd['sampler_weights'] = sampler_weights_samples.astype(np.float32)
+        else:
+            dd['sampler_weights'] = np.ones(len(y_samples)).astype(np.float32)
         return MyDataset(dd,t_samples,g_samples,idx_target,idx_target_future)
 
 
@@ -753,8 +763,14 @@ class TimeSeries():
         else:
             self.modifier = None
 
+        if self.sampler_weights is not None:
+            beauty_string(f'USING SAMPLER IN TRAIN {min(train.sampler_weights)}-{max(train.sampler_weights)}','section',self.verbose)
+
+            sampler = WeightedRandomSampler(train.sampler_weights, num_samples= len(train))
+            train_dl = DataLoader(train, batch_size = batch_size , shuffle=False,sampler=sampler,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
 
-        train_dl = DataLoader(train, batch_size = batch_size , shuffle=True,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
+        else:
+            train_dl = DataLoader(train, batch_size = batch_size , shuffle=True,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
         valid_dl = DataLoader(validation, batch_size = batch_size , shuffle=False,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
 
         checkpoint_callback = ModelCheckpoint(dirpath=dirpath,
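This is the heart of the feature: when weights are present, shuffle=True is replaced by a WeightedRandomSampler (PyTorch forbids combining a sampler with shuffle=True, hence the explicit shuffle=False branch). The all-ones fallback stored in the dataset earlier only guarantees the sampler_weights attribute always exists; the unweighted branch still uses plain shuffling. A self-contained sketch of the same pattern, with illustrative names rather than the dsipts API:

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

train = TensorDataset(torch.randn(8, 3), torch.arange(8))

# one weight per sample: sample 0 is drawn roughly ten times more often
weights = torch.tensor([10.0, 1, 1, 1, 1, 1, 1, 1])
sampler = WeightedRandomSampler(weights, num_samples=len(train))  # replacement=True by default

train_dl = DataLoader(train, batch_size=4, shuffle=False, sampler=sampler, drop_last=True)
for _, yb in train_dl:
    print(yb)  # index 0 should show up disproportionately often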
@@ -1026,7 +1042,7 @@ class TimeSeries():
 
         if self.group is not None:
             time[self.group] = groups
-            time = time.melt(id_vars=['region'])
+            time = time.melt(id_vars=[self.group])
         else:
             time = time.melt()
         time.rename(columns={'value':'time','variable':'lag'},inplace=True)
@@ -1048,7 +1064,8 @@ class TimeSeries():
 
         if self.group is not None:
             time[self.group] = groups
-            time = time.melt(id_vars=['region'])
+
+            time = time.melt(id_vars=[self.group])
         else:
             time = time.melt()
         time.rename(columns={'value':'time','variable':'lag'},inplace=True)
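Both hunks above fix the same bug: melt(id_vars=['region']) raised a KeyError whenever the group column was named anything but 'region'. A small pandas sketch of the corrected pattern (toy frame, illustrative names):

import pandas as pd

group = 'store'  # whatever column name was passed as group
time = pd.DataFrame({'lag_1': [1, 2], 'lag_2': [3, 4]})
time[group] = ['A', 'B']

time = time.melt(id_vars=[group])  # works for any group name, not just 'region'
time.rename(columns={'value': 'time', 'variable': 'lag'}, inplace=True)
print(time)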
dsipts/data_structure/utils.py CHANGED
@@ -142,12 +142,13 @@ class MyDataset(Dataset):
         Returns:
             torch.utils.data.Dataset: a torch Dataset to be used in a Dataloader
         """
+
         self.data = data
         self.t = t
         self.groups = groups
         self.idx_target = np.array(idx_target) if idx_target is not None else None
         self.idx_target_future = np.array(idx_target_future) if idx_target_future is not None else None
-
+        self.sampler_weights = data['sampler_weights']
 
 
     def __len__(self):
@@ -157,7 +158,8 @@ class MyDataset(Dataset):
     def __getitem__(self, idxs):
         sample = {}
         for k in self.data:
-            sample[k] = self.data[k][idxs]
+            if k!='sampler_weights':
+                sample[k] = self.data[k][idxs]
         if self.idx_target is not None:
             sample['idx_target'] = self.idx_target
         if self.idx_target_future is not None:
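With __getitem__ now skipping the sampler_weights key, the weights steer the WeightedRandomSampler but never leak into the batches the model sees. A minimal sketch of the pattern (not the dsipts class itself):

import numpy as np
from torch.utils.data import Dataset

class WeightedDictDataset(Dataset):
    def __init__(self, data):
        self.data = data
        # exposed so a WeightedRandomSampler can read it
        self.sampler_weights = data['sampler_weights']

    def __len__(self):
        return len(self.data['y'])

    def __getitem__(self, idx):
        # everything except the weights goes to the model
        return {k: v[idx] for k, v in self.data.items() if k != 'sampler_weights'}

ds = WeightedDictDataset({'y': np.zeros((4, 2)), 'sampler_weights': np.ones(4)})
assert 'sampler_weights' not in ds[0]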
dsipts/models/TTM.py CHANGED
@@ -38,6 +38,7 @@ class TTM(Base):
                  fcm_mix_layers,
                  fcm_prepend_past,
                  enable_forecast_channel_mixing,
+                 force_return,
                  **kwargs)->None:
 
         super().__init__(**kwargs)
@@ -48,7 +49,9 @@ class TTM(Base):
         self.index_fut = list(exogenous_channel_indices_cont)
 
         if len(exogenous_channel_indices_cat)>0:
-            self.index_fut_cat = (self.past_channels+len(self.embs_past))+list(exogenous_channel_indices_cat)
+
+            self.index_fut_cat = [self.past_channels+c for c in list(exogenous_channel_indices_cat)]
+
         else:
             self.index_fut_cat = []
         self.freq = freq
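The replaced line was a genuine bug: adding an int to a list raises a TypeError in Python, so any configuration with categorical exogenous channels would have crashed here. The fix offsets each categorical index past the continuous channels with a comprehension. A quick illustration with toy values:

past_channels = 5
exogenous_channel_indices_cat = [0, 1]

# old: (past_channels + len(embs_past)) + list(exogenous_channel_indices_cat)
#      -> TypeError: unsupported operand type(s) for +: 'int' and 'list'
index_fut_cat = [past_channels + c for c in exogenous_channel_indices_cat]
print(index_fut_cat)  # [5, 6]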
@@ -75,6 +78,7 @@ class TTM(Base):
             fcm_use_mixer=fcm_use_mixer,
             fcm_mix_layers=fcm_mix_layers,
             freq=freq,
+            force_return=force_return,
             freq_prefix_tuning=freq_prefix_tuning,
             fcm_prepend_past=fcm_prepend_past,
             enable_forecast_channel_mixing=enable_forecast_channel_mixing,
@@ -83,7 +87,7 @@ class TTM(Base):
         hidden_size = self.model.config.hidden_size
         self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
         self._freeze_backbone()
-
+        self.zero_pad = (force_return=='zeropad')
     def _freeze_backbone(self):
         """
         Freeze the backbone of the model.
@@ -108,29 +112,44 @@ class TTM(Base):
         return input
 
     def can_be_compiled(self):
-        return True
+
+        return True#not self.zero_pad
 
     def forward(self, batch):
         x_enc = batch['x_num_past'].to(self.device)
+
+
+        if self.zero_pad:
+            B,L,C = batch['x_num_past'].shape
+            x_enc = torch.zeros((B,512,C)).to(self.device)
+            x_enc[:,-L:,:] = batch['x_num_past'].to(self.device)
+        else:
+            x_enc = batch['x_num_past'].to(self.device)
         original_indexes = batch['idx_target'][0].tolist()
 
 
         if 'x_cat_past' in batch.keys():
-            x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
-            x_mark_enc = self._scaler_past(x_mark_enc)
+            if self.zero_pad:
+                B,L,C = batch['x_cat_past'].shape
+                x_mark_enc = torch.zeros((B,512,C)).to(self.device)
+                x_mark_enc[:,-L:,:] = batch['x_cat_past'].to(torch.float32).to(self.device)
+            else:
+                x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
+            x_mark_enc = self._scaler_past(x_mark_enc)
             past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
         else:
             past_values = x_enc
+        B,L,C = past_values.shape
+        future_values = torch.zeros((B,self.future_steps,C)).to(self.device)
 
-        future_values = torch.zeros_like(past_values).to(self.device)
-        future_values = future_values[:,:self.future_steps,:]
 
+
         if 'x_num_future' in batch.keys():
             future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
         if 'x_cat_future' in batch.keys():
             x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
             x_mark_dec = self._scaler_fut(x_mark_dec)
-            future_values[:,:,self.index_cat_fut] = x_mark_dec
+            future_values[:,:,self.index_fut_cat] = x_mark_dec
 
 
         #investigating!! problem with dynamo!
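Three things change in forward: with force_return == 'zeropad' the past window is left-padded with zeros up to a fixed 512-step context (512 is hard-coded in the diff; presumably the TTM backbone's context length); future_values is now allocated directly as (B, future_steps, C) instead of slicing a zeros_like of the past, whose old slicing silently produced too few steps whenever future_steps exceeded the past length; and the categorical write now targets self.index_fut_cat, the attribute actually defined in __init__, where the old name index_cat_fut would have raised an AttributeError. A sketch of just the padding step, with illustrative shapes:

import torch

CONTEXT_LEN = 512                 # taken from the diff above
x = torch.randn(4, 96, 3)         # B=4 windows, L=96 past steps, C=3 channels

B, L, C = x.shape
padded = torch.zeros((B, CONTEXT_LEN, C))
padded[:, -L:, :] = x             # zeros on the left, real history on the right

assert padded.shape == (4, 512, 3)
assert torch.equal(padded[:, -96:, :], x)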
dsipts-1.1.13.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dsipts
-Version: 1.1.12
+Version: 1.1.13
 Summary: Unified library for timeseries modelling
 Author-email: Andrea Gobbi <agobbi@fbk.eu>
 Project-URL: Homepage, https://github.com/DSIP-FBK/DSIPTS
dsipts-1.1.13.dist-info/RECORD CHANGED
@@ -3,9 +3,9 @@ dsipts/data_management/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
 dsipts/data_management/monash.py,sha256=aZxq9FbIH6IsU8Lwou1hAokXjgOAK-wdl2VAeFg2k4M,13075
 dsipts/data_management/public_datasets.py,sha256=yXFzOZZ-X0ZG1DoqVU-zFmEGVMc2033YDQhRgYxY8ws,6793
 dsipts/data_structure/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dsipts/data_structure/data_structure.py,sha256=KVkjTVjc7NznJIou4LYGzMbzE7ye-K3ll65GEgn2qKg,60814
+dsipts/data_structure/data_structure.py,sha256=vOiVuTbEprXwpf8l5hk-7HP3v_r5d3YiXuQdGwo4nV0,62295
 dsipts/data_structure/modifiers.py,sha256=qlry9dfw8pEE0GrvgwROZJkJ6oPpUnjEHPIG5qIetss,7948
-dsipts/data_structure/utils.py,sha256=QwfKPZgSy6DIw5n6ztOdPJIAnzo4EnlMTgRbpiWnyko,6593
+dsipts/data_structure/utils.py,sha256=ZL-z_InmFUkge5kQoHSrev1t6nyve9sTYTVeA75Or-I,6689
 dsipts/models/Autoformer.py,sha256=nUQvPC_qtajLT1AHdNJmF_P3ZL01j3spkZ4ubxdGF3g,8497
 dsipts/models/CrossFormer.py,sha256=ClW6H_hrtLJH0iqTC7q_ya_Bwc_Xu-0lpAN5w2DSUYk,6526
 dsipts/models/D3VAE.py,sha256=d1aY6kGjBSxZncN-KPWpdUGunu182ng2QFInGFrKYQM,6903
@@ -23,7 +23,7 @@ dsipts/models/Samformer.py,sha256=Kt7B9ID3INtFDAVKIM1LTly5-UfKCaVZ9uxAJmYv6B4,56
 dsipts/models/Simple.py,sha256=8wRSO-gh_Z6Sl8fYMV-RIXIL0RrO5u5dDtsaq-OsKg0,3960
 dsipts/models/TFT.py,sha256=JiI90ikfP8aaR_rtczu8CyGMNLTgml13aYQifgIC_yo,13888
 dsipts/models/TIDE.py,sha256=S1KlKqFOR3jJ9DDiTqeaKvya9hYBsNHBVqwJsYX3FLU,13094
-dsipts/models/TTM.py,sha256=lOOo5dR5nOmf37cND6C8ft8TVl0kzNeraIuABw7eI5g,5897
+dsipts/models/TTM.py,sha256=PoRDT-KYoMqv6yIOU-73E7Y2pRyd4lga0u6KrJRd5DU,6561
 dsipts/models/TimeXER.py,sha256=EkmlHfT2RegY6Ce6q8EUEV1a_WZ6SkYibnOZXqsyd_8,7111
 dsipts/models/VQVAEA.py,sha256=sNJi8UZh-10qEIKcZK3SzhlOFUUjvqjoglzeZBFaeZM,13789
 dsipts/models/VVA.py,sha256=BnPkJ0Nzue0oShSHZVRNlf5RvT0Iwtf9bx19vLB9Nn0,11939
@@ -76,7 +76,7 @@ dsipts/models/vva/minigpt.py,sha256=bg0JddqSD322uxSGexen3nPXL_hGTsk3vNLR62d7-w8,
 dsipts/models/vva/vqvae.py,sha256=RzCQ_M9xBprp7_x20dSV3EQqlO0FjPUGWV-qdyKrQsM,19680
 dsipts/models/xlstm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dsipts/models/xlstm/xLSTM.py,sha256=ZKZZmffmIq1Vb71CR4GSyM8viqVx-u0FChxhcNgHub8,10081
-dsipts-1.1.12.dist-info/METADATA,sha256=nxE2kAg9RvG5Py27sMNbQ-mUIu9mtZrDo2WocLpJdQ4,24795
-dsipts-1.1.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dsipts-1.1.12.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
-dsipts-1.1.12.dist-info/RECORD,,
+dsipts-1.1.13.dist-info/METADATA,sha256=6UZ0nHk0RoGXxxkPYCyB0w41m8LlOE5BfoiswplloXQ,24795
+dsipts-1.1.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dsipts-1.1.13.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
+dsipts-1.1.13.dist-info/RECORD,,