dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/TimeXER.py
@@ -0,0 +1,184 @@
+ ## Copyright https://github.com/thuml/Time-Series-Library/blob/main/models/TimeMixer.py
+ ## Modified for notation alignment and batch structure
+ ## extended with what is inside the timexer folder
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+
+ try:
+     import lightning.pytorch as pl
+     from .base_v2 import Base
+     OLD_PL = False
+ except:
+     import pytorch_lightning as pl
+     OLD_PL = True
+     from .base import Base
+ from .utils import QuantileLossMO, Permute, get_activation
+ from .itransformer.SelfAttention_Family import FullAttention, AttentionLayer
+ from .itransformer.Embed import DataEmbedding_inverted
+ from .timexer.Layers import FlattenHead, EnEmbedding, EncoderLayer, Encoder
+
+ from typing import List, Union
+ from ..data_structure.utils import beauty_string
+ from .utils import get_scope
+ from .utils import Embedding_cat_variables
+
+
+ class TimeXER(Base):
+     handle_multivariate = True
+     handle_future_covariates = True  # or at least it seems...
+     handle_categorical_variables = True  # only in the encoder
+     handle_quantile_loss = True  # NOT EFFICIENTLY ADDED, TODO fix this
+     description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
+     def __init__(self,
+                  patch_len:int,
+                  d_model: int,
+                  n_head: int,
+                  d_ff:int=512,
+                  dropout_rate: float=0.1,
+                  n_layer_decoder: int=1,
+                  activation: str='',
+                  **kwargs)->None:
+         """Initialize the model with the specified parameters. https://github.com/thuml/Time-Series-Library/blob/main/models/TimeMixer.py
+
+         Args:
+             patch_len (int): Length of the patches.
+             d_model (int): Dimension of the model.
+             n_head (int): Number of attention heads.
+             d_ff (int, optional): Dimension of the feedforward network. Defaults to 512.
+             dropout_rate (float, optional): Dropout rate for regularization. Defaults to 0.1.
+             n_layer_decoder (int, optional): Number of layers in the decoder. Defaults to 1.
+             activation (str, optional): Activation function to use. Defaults to ''.
+             **kwargs: Additional keyword arguments passed to the superclass.
+
+         Raises:
+             ValueError: If an invalid activation function is provided.
+         """
+         super().__init__(**kwargs)
+         if activation == 'torch.nn.SELU':
+             beauty_string('SELU does not require BN','info',self.verbose)
+             use_bn = False
+         if isinstance(activation,str):
+             activation = get_activation(activation)
+         self.save_hyperparameters(logger=False)
+
+         self.patch_len = patch_len
+         self.patch_num = int(self.past_steps // patch_len)
+         d_model = d_model*self.mul
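+         # editor's note (not in the upstream source): past_steps is assumed to be a
+         # multiple of patch_len; the integer division above silently drops any remainder steps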
+
+         self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+         self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+         emb_past_out_channel = self.emb_past.output_channels
+         emb_fut_out_channel = self.emb_fut.output_channels
+
+         self.output_attention = False  ## output attention is not needed
+
+         ###
+         self.en_embedding = EnEmbedding(self.past_channels, d_model, patch_len, dropout_rate)
+
+         self.ex_embedding = DataEmbedding_inverted(self.past_steps, d_model, embed_type='what?', freq='what?', dropout=dropout_rate)  ## embed_type and freq are not used inside
+
+         # Encoder-only architecture
+         self.encoder = Encoder(
+             [
+                 EncoderLayer(
+                     AttentionLayer(
+                         FullAttention(False, factor=0.1, attention_dropout=dropout_rate,  ## NB factor is not used
+                                       output_attention=False),
+                         d_model, n_head),
+                     AttentionLayer(
+                         FullAttention(False, 0.1, attention_dropout=dropout_rate,
+                                       output_attention=False),
+                         d_model, n_head),
+                     d_model,
+                     d_ff,
+                     dropout=dropout_rate,
+                     activation=activation(),
+                 )
+                 for l in range(n_layer_decoder)
+             ],
+             norm_layer=torch.nn.LayerNorm(d_model)
+         )
+         self.head_nf = d_model * (self.patch_num + 1)
+         self.head = FlattenHead(self.past_channels, self.head_nf, self.future_steps*self.mul, head_dropout=dropout_rate)
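+         # editor's note (assumption): the "+ 1" in head_nf presumably accounts for the extra
+         # global exogenous token that EnEmbedding appends to the patch_num patch tokens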
+
+         self.future_reshape = nn.Linear(self.future_steps, self.future_steps*self.mul)
+         self.final_linear = nn.Sequential(activation(),
+                                           nn.Linear(self.past_channels+self.future_channels+emb_fut_out_channel, (self.past_channels+self.future_channels+emb_fut_out_channel)//2),
+                                           activation(),
+                                           nn.Linear((self.past_channels+self.future_channels+emb_fut_out_channel)//2, self.out_channels)
+                                           )
+
+     def forward(self, batch:dict) -> torch.Tensor:
+
+         x_enc = batch['x_num_past'].to(self.device)
+
+         BS = x_enc.shape[0]
+         if 'x_cat_future' in batch.keys():
+             emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
+         else:
+             emb_fut = self.emb_fut(BS, None)
+         tmp_future = [emb_fut]
+         if 'x_cat_past' in batch.keys():
+             emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
+         else:
+             emb_past = self.emb_past(BS, None)
+
+         if 'x_num_future' in batch.keys():
+             x_future = batch['x_num_future'].to(self.device)
+             tmp_future.append(x_future)
+         if len(tmp_future) > 0:
+             tmp_future = torch.cat(tmp_future, 2)
+         else:
+             tmp_future = None
+
+         en_embed, n_vars = self.en_embedding(x_enc.permute(0, 2, 1))
+         ex_embed = self.ex_embedding(x_enc, emb_past)
+
+         enc_out = self.encoder(en_embed, ex_embed)
+
+         enc_out = torch.reshape(enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
+
+         # z: [bs x nvars x d_model x patch_num]
+         enc_out = enc_out.permute(0, 1, 3, 2)
+
+         dec_out = self.head(enc_out)  # z: [bs x nvars x target_window]
+         #dec_out = dec_out.permute(0, 2, 1)
+         if tmp_future is not None:
+             tmp_future = self.future_reshape(tmp_future.permute(0, 2, 1))
+             dec_out = torch.cat([tmp_future, dec_out], 1)
+         dec_out = self.final_linear(dec_out.permute(0, 2, 1))
+         return dec_out.reshape(BS, self.future_steps, self.out_channels, self.mul)
+
+         #idx_target = batch['idx_target'][0]
+         #return dec_out[:, :, idx_target].reshape(BS, self.future_steps, self.out_channels, self.mul)
+
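For readers who want to trace the head shapes, the patch bookkeeping above can be checked in isolation. A minimal sketch with made-up numbers (treating mul, which is set by the Base class, as the number of quantile outputs is my assumption, not something stated in this file):

    # hypothetical values, only to illustrate TimeXER's patch arithmetic
    past_steps, patch_len = 64, 16
    d_model, mul = 32, 1                     # assumption: mul > 1 only when quantile loss is active

    d_model = d_model * mul                  # 32, as done in __init__
    patch_num = int(past_steps // patch_len) # 4 patches over the past window
    head_nf = d_model * (patch_num + 1)      # 160 features fed to FlattenHead
    print(patch_num, head_nf)                # -> 4 160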
dsipts/models/VQVAEA.py
@@ -0,0 +1,299 @@
+
+ ## Copyright (c) 2020 Andrej Karpathy https://github.com/karpathy/minGPT?tab=MIT-1-ov-file#readme
+ ## Modified for notation alignment, batch structure and adapted for time series
+ ## extended with what is inside the vva folder
+
+ from torch import nn
+ import torch
+ from torch.nn import functional as F
+
+ try:
+     import lightning.pytorch as pl
+     from .base_v2 import Base
+     OLD_PL = False
+ except:
+     import pytorch_lightning as pl
+     OLD_PL = True
+     from .base import Base
+ from typing import List, Union
+ from .vva.minigpt import Block
+ from .vva.vqvae import VQVAE
+ import logging
+ from random import random
+ from ..data_structure.utils import beauty_string
+ from .utils import get_scope
+
+ torch.autograd.set_detect_anomaly(True)
+
+
+ class VQVAEA(Base):
+     handle_multivariate = False
+     handle_future_covariates = False
+     handle_categorical_variables = False
+     handle_quantile_loss = False
+     description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
+
+     def __init__(self,
+                  past_steps:int,
+                  future_steps:int,
+                  past_channels:int,
+                  future_channels:int,
+                  hidden_channels:int,
+                  embs:List[int],
+                  d_model:int,
+                  max_voc_size:int,
+                  num_layers:int,
+                  dropout_rate:float,
+                  commitment_cost:float,
+                  decay:float,
+                  n_heads:int,
+                  out_channels:int,
+                  epoch_vqvae: int,
+                  persistence_weight:float=0.0,
+                  loss_type: str='l1',
+                  quantiles:List[int]=[],
+                  optim:Union[str,None]=None,
+                  optim_config:dict=None,
+                  scheduler_config:dict=None,
+                  **kwargs)->None:
+         """Custom encoder-decoder: a VQ-VAE tokenizer followed by a minGPT-style autoregressive model.
+
+         Args:
+             past_steps (int): number of past datapoints used (must be even)
+             future_steps (int): number of future lags to predict (must be even)
+             past_channels (int): number of numeric past variables, must be >0
+             future_channels (int): number of future numeric variables
+             hidden_channels (int): hidden channels of the VQ-VAE encoder/decoder
+             embs (List[int]): list of the initial dimensions of the categorical variables
+             d_model (int): dimension of the codebook vectors and of the GPT embeddings
+             max_voc_size (int): number of codebook entries, i.e. the GPT vocabulary size
+             num_layers (int): number of transformer blocks in the GPT
+             dropout_rate (float): dropout rate in Dropout layers
+             commitment_cost (float): commitment cost of the vector quantizer
+             decay (float): decay parameter of the vector quantizer
+             n_heads (int): number of attention heads
+             out_channels (int): number of output channels (only 1 is supported)
+             epoch_vqvae (int): epoch after which the VQ-VAE is frozen and only the GPT is trained
+             persistence_weight (float): weight controlling the divergence from the persistence model. Default 0
+             loss_type (str, optional): this model uses custom losses or l1 or mse. Custom losses can be linear_penalization or exponential_penalization. Default l1
+             quantiles (List[int], optional): quantile loss is used if len(quantiles)>0 (usually 0.1, 0.5, 0.9), otherwise L1Loss. Defaults to [].
+             optim (str, optional): if not None it expects a pytorch optim method. Defaults to None, which is mapped to Adam.
+             optim_config (dict, optional): configuration for the optimizer. Defaults to None.
+             scheduler_config (dict, optional): configuration for the stepLR scheduler. Defaults to None.
+         """
+
+         super().__init__(**kwargs)
+         self.save_hyperparameters(logger=False)
+         self.d_model = d_model
+         self.max_voc_size = max_voc_size
+         self.future_steps = future_steps
+         self.epoch_vqvae = epoch_vqvae
+         ## FIRST THE VQ-VAE
+         assert out_channels==1, beauty_string('Working only for one signal','section',True)
+         assert past_steps%2==0 and future_steps%2==0, beauty_string('There are some issues with the decoder in case of odd lengths','section',True)
+         self.vqvae = VQVAE(in_channels=1, hidden_channels=hidden_channels, out_channels=1, num_embeddings=max_voc_size, embedding_dim=d_model, commitment_cost=commitment_cost, decay=decay)
+
+         ## THEN THE GPT
+         self.block_size = past_steps//2 + future_steps//2 - 1
+         self.sentence_length = future_steps//2
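+         ## editor's note (inference from the asserts above, not stated in the source): the VQ-VAE
+         ## encoder appears to halve the temporal resolution, so the GPT sees past_steps//2 context
+         ## tokens and predicts future_steps//2 tokens, which gives the block_size above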
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(max_voc_size, d_model),
+             wpe = nn.Embedding(self.block_size, d_model),
+             drop = nn.Dropout(dropout_rate),
+             h = nn.ModuleList([Block(d_model, dropout_rate, n_heads, dropout_rate, self.block_size) for _ in range(num_layers)]),  ## careful: the two dropout arguments could differ
+             ln_f = nn.LayerNorm(d_model),
+             lm_head = nn.Linear(d_model, max_voc_size, bias=False)
+         ))
+         # report number of parameters (note we don't count the decoder parameters in lm_head)
+         n_params = sum(p.numel() for p in self.transformer.parameters())
+         beauty_string("number of parameters: %.2fM" % (n_params/1e6,), 'info', self.verbose)
+
+         self.use_quantiles = False
+         self.is_classification = True
+         self.optim_config = optim_config
+
+     def configure_optimizers(self):
+
+         #return torch.optim.Adam(self.vqvae.parameters(), lr=self.optim_config.lr_vqvae,
+         #                        weight_decay=self.optim_config.weight_decay_vqvae)
+
+         return torch.optim.AdamW([
+             {'params': self.vqvae.parameters(), 'lr': self.optim_config.lr_vqvae, 'weight_decay': self.optim_config.weight_decay_vqvae},
+             {'params': self.transformer.parameters(), 'lr': self.optim_config.lr_gpt, 'weight_decay': self.optim_config.weight_decay_gpt},
+         ])
+
+     def gpt(self, tokens):
+
+         b, t = tokens.size()
+         assert t <= self.block_size, beauty_string(f"Cannot forward sequence of length {t}, block size is only {self.block_size}", 'section', True)
+         pos = torch.arange(0, t, dtype=torch.long, device=self.device).unsqueeze(0)  # shape (1, t)
+
+         # forward the GPT model itself
+         tok_emb = self.transformer.wte(tokens)  # token embeddings of shape (b, t, n_embd)
+         pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
+         x = self.transformer.drop(tok_emb + pos_emb)
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x)
+         logits = self.transformer.lm_head(x)
+         return logits
+
+     def forward(self, batch):
+
+         ## VQ-VAE
+         #current_epoch = self.current_epoch
+         #if current_epoch < 1000:
+         #    self.vqvae.train()
+         #    loss_gpt = 100
+         #else:
+         #    self.vqvae.eval()
+         idx_target = batch['idx_target'][0]
+
+         #(tensor([194, 163, 174, 176, 160, 168, 175]),
+         # tensor([ -1,  -1,  -1, 160, 168, 175, 160]))
+
+         data = batch['x_num_past'][:,:,idx_target]
+         if self.current_epoch > self.epoch_vqvae:
+             with torch.no_grad():
+                 vq_loss, data_recon, perplexity, quantized_x, encodings_x = self.vqvae(data.permute(0,2,1))
+             loss_vqvae = 0
+         else:
+             vq_loss, data_recon, perplexity, quantized_x, encodings_x = self.vqvae(data.permute(0,2,1))
+             if random() < 0.001:
+                 beauty_string(perplexity, 'info', self.verbose)
+             recon_error = F.mse_loss(data_recon.squeeze(), data.squeeze())
+             loss_vqvae = recon_error + vq_loss
+
+         if self.current_epoch > self.epoch_vqvae:
+             with torch.no_grad():
+                 _, _, _, quantized_y, encodings_y = self.vqvae(batch['y'].permute(0,2,1))
+
+             ## GPT
+             tokens = torch.cat([encodings_x.argmax(dim=2), encodings_y.argmax(dim=2)[:,0:-1]], 1)
+             tokens_y = torch.cat([encodings_x.argmax(dim=2)[:,0:-1], encodings_y.argmax(dim=2)], 1)
+             tokens_y[:,0:encodings_x.shape[1]-1] = -1
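+             # the positions coming from the encoded past context are set to -1 so that the
+             # cross-entropy below (ignore_index=-1) only scores the future tokens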
+             logits = self.gpt(tokens)
+             loss_gpt = F.cross_entropy(logits.view(-1, logits.size(-1)), tokens_y.view(-1), ignore_index=-1)
+
+             ## now we have to reconstruct y, because that is what we want as output
+             with torch.no_grad():
+                 encoding_indices = torch.argmax(logits.reshape(-1, self.max_voc_size), dim=1).unsqueeze(1)
+                 encodings = torch.zeros(encoding_indices.shape[0], self.vqvae._vq_vae._num_embeddings, device=self.device)
+                 encodings.scatter_(1, encoding_indices, 1)
+                 quantized = torch.matmul(encodings, self.vqvae._vq_vae._embedding.weight).view(data.shape[0], -1, self.d_model)  ## B x L x hidden
+                 quantized = quantized.permute(0, 2, 1).contiguous()
+                 y_hat = self.vqvae._decoder(quantized, False).squeeze()[:,-self.future_steps:]
+
+             l1_loss = nn.L1Loss()(y_hat, batch['y'].squeeze())
+
+             return y_hat, loss_vqvae + loss_gpt + l1_loss
+         else:
+             return None, loss_vqvae
+
+     def training_step(self, batch, batch_idx):
+         """
+         pytorch lightning stuff
+
+         :meta private:
+         """
+         _, loss = self(batch)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         """
+         pytorch lightning stuff
+
+         :meta private:
+         """
+         _, loss = self(batch)
+         return loss
+
+     def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None, num_samples=100):
+         """
+         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+         the sequence max_new_tokens times, feeding the predictions back into the model each time.
+         Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+         """
+         assert do_sample is False, logging.info('NOT IMPLEMENTED YET')
+         if do_sample:
+             idx = idx.repeat(num_samples, 1, 1)
+             for _ in range(max_new_tokens):
+                 tmp = []
+                 for i in range(num_samples):
+                     idx_cond = idx[i,:,:] if idx.size(2) <= self.block_size else idx[i,:, -self.block_size:]
+                     logits = self.gpt(idx_cond)
+                     logits = logits[:, -1, :] / temperature
+                     if top_k is not None:
+                         v, _ = torch.topk(logits, top_k)
+                         logits[logits < v[:, [-1]]] = -float('Inf')
+                     probs = F.softmax(logits, dim=-1)
+                     idx_next = torch.multinomial(probs, num_samples=1, replacement=True)
+                     tmp.append(idx_next)
+                 tmp = torch.cat(tmp, dim=1).T.unsqueeze(2)
+                 idx = torch.cat((idx, tmp), dim=2)
+             return idx
+         else:
+             for _ in range(max_new_tokens):
+                 # if the sequence context is growing too long we must crop it at block_size
+                 idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
+                 # forward the model to get the logits for the index in the sequence
+                 logits = self.gpt(idx_cond)
+                 # pluck the logits at the final step and scale by desired temperature
+                 logits = logits[:, -1, :] / temperature
+                 # optionally crop the logits to only the top k options
+                 if top_k is not None:
+                     v, _ = torch.topk(logits, top_k)
+                     logits[logits < v[:, [-1]]] = -float('Inf')
+                 # apply softmax to convert logits to (normalized) probabilities
+                 probs = F.softmax(logits, dim=-1)
+                 # take the most likely element (greedy decoding, since sampling is not implemented)
+                 _, idx_next = torch.topk(probs, k=1, dim=-1)
+                 # append the chosen index to the running sequence and continue
+                 idx = torch.cat((idx, idx_next), dim=1)
+
+             return idx
+
+     def inference(self, batch:dict) -> torch.Tensor:
+
+         idx_target = batch['idx_target'][0]
+         data = batch['x_num_past'][:,:,idx_target].to(self.device)
+         vq_loss, data_recon, perplexity, quantized_x, encodings_x = self.vqvae(data.permute(0,2,1))
+         x = encodings_x.argmax(dim=2)
+         inp = x[:, :self.sentence_length]
+         # let the model complete the rest of the sequence
+         cat = self.generate(inp, self.sentence_length, do_sample=False)  # I cannot handle sampling here :-)
+         encoding_indices = cat.flatten().unsqueeze(1)
+         encodings = torch.zeros(encoding_indices.shape[0], self.vqvae._vq_vae._num_embeddings, device=self.device)
+         encodings.scatter_(1, encoding_indices, 1)
+         quantized = torch.matmul(encodings, self.vqvae._vq_vae._embedding.weight).view(x.shape[0], -1, self.d_model)  ## B x L x hidden
+         quantized = quantized.permute(0, 2, 1).contiguous()
+         y_hat = self.vqvae._decoder(quantized, False).squeeze()[:,-self.future_steps:]
+
+         ## BxLxCx3
+         return y_hat.unsqueeze(2).unsqueeze(3)
+
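The one-hot scatter followed by a matmul with the codebook weights, used both in forward and in inference to turn predicted token ids back into codebook vectors, is equivalent to a plain embedding lookup. A self-contained sketch of the pattern, using an nn.Embedding as a stand-in for the package's VQ-VAE codebook (an assumption; the real codebook lives inside self.vqvae._vq_vae):

    import torch

    B, L, K, D = 2, 8, 128, 16                 # batch, tokens, codebook size, embedding dim
    codebook = torch.nn.Embedding(K, D)        # stand-in for the VQ-VAE codebook
    indices = torch.randint(0, K, (B * L, 1))  # predicted token ids, shape (B*L, 1)

    # one-hot scatter followed by a matmul with the codebook weights...
    one_hot = torch.zeros(B * L, K)
    one_hot.scatter_(1, indices, 1)
    quantized = one_hot @ codebook.weight      # (B*L, D)

    # ...is equivalent to a direct embedding lookup
    assert torch.allclose(quantized, codebook(indices.squeeze(1)))

    quantized = quantized.view(B, L, D).permute(0, 2, 1)  # B x D x L, as expected by the decoder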