dsipts-1.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dsipts might be problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
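
The listing above can be checked against the wheel itself; a minimal sketch in Python, assuming the wheel was first fetched (e.g. with pip download dsipts==1.1.5 --no-deps) and that it carries the standard wheel filename:

from zipfile import ZipFile

# A wheel is a plain zip archive: listing its members should match the RECORD above.
with ZipFile("dsipts-1.1.5-py3-none-any.whl") as whl:
    for name in sorted(whl.namelist()):
        print(name)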
dsipts/models/TimeXER.py
ADDED
@@ -0,0 +1,184 @@
## Copyright https://github.com/thuml/Time-Series-Library/blob/main/models/TimeMixer.py
## Modified for notation alignment and batch structure
## Extended with what is inside the timexer folder

import torch
import torch.nn as nn
import numpy as np


try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from .utils import QuantileLossMO, Permute, get_activation
from .itransformer.SelfAttention_Family import FullAttention, AttentionLayer
from .itransformer.Embed import DataEmbedding_inverted
from .timexer.Layers import FlattenHead, EnEmbedding, EncoderLayer, Encoder

from typing import List, Union
from ..data_structure.utils import beauty_string
from .utils import get_scope
from .utils import Embedding_cat_variables


class TimeXER(Base):
    handle_multivariate = True
    handle_future_covariates = True  # or at least it seems...
    handle_categorical_variables = True  # only in the encoder
    handle_quantile_loss = True  # NOT EFFICIENTLY ADDED, TODO fix this
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 patch_len: int,
                 d_model: int,
                 n_head: int,
                 d_ff: int = 512,
                 dropout_rate: float = 0.1,
                 n_layer_decoder: int = 1,
                 activation: str = '',
                 **kwargs) -> None:
        """Initialize the model with the specified parameters. https://github.com/thuml/Time-Series-Library/blob/main/models/TimeMixer.py

        Args:
            patch_len (int): length of the patches.
            d_model (int): dimension of the model.
            n_head (int): number of attention heads.
            d_ff (int, optional): dimension of the feedforward network. Defaults to 512.
            dropout_rate (float, optional): dropout rate for regularization. Defaults to 0.1.
            n_layer_decoder (int, optional): number of layers in the decoder. Defaults to 1.
            activation (str, optional): activation function to use. Defaults to ''.
            **kwargs: additional keyword arguments passed to the superclass.

        Raises:
            ValueError: if an invalid activation function is provided.
        """
        super().__init__(**kwargs)
        if activation == 'torch.nn.SELU':
            beauty_string('SELU does not require BN', 'info', self.verbose)
            use_bn = False
        if isinstance(activation, str):
            activation = get_activation(activation)
        self.save_hyperparameters(logger=False)

        self.patch_len = patch_len
        self.patch_num = int(self.past_steps // patch_len)
        d_model = d_model * self.mul

        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        self.output_attention = False  ## output attention not needed

        self.en_embedding = EnEmbedding(self.past_channels, d_model, patch_len, dropout_rate)

        self.ex_embedding = DataEmbedding_inverted(self.past_steps, d_model, embed_type='what?', freq='what?', dropout=dropout_rate)  ## embed_type and freq are not used inside

        # Encoder-only architecture
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, factor=0.1, attention_dropout=dropout_rate,  ## NB factor is not used
                                      output_attention=False),
                        d_model, n_head),
                    AttentionLayer(
                        FullAttention(False, 0.1, attention_dropout=dropout_rate,
                                      output_attention=False),
                        d_model, n_head),
                    d_model,
                    d_ff,
                    dropout=dropout_rate,
                    activation=activation(),
                )
                for l in range(n_layer_decoder)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )
        self.head_nf = d_model * (self.patch_num + 1)
        self.head = FlattenHead(self.past_channels, self.head_nf, self.future_steps * self.mul, head_dropout=dropout_rate)

        self.future_reshape = nn.Linear(self.future_steps, self.future_steps * self.mul)
        self.final_linear = nn.Sequential(activation(),
                                          nn.Linear(self.past_channels + self.future_channels + emb_fut_out_channel, (self.past_channels + self.future_channels + emb_fut_out_channel) // 2),
                                          activation(),
                                          nn.Linear((self.past_channels + self.future_channels + emb_fut_out_channel) // 2, self.out_channels)
                                          )

    def forward(self, batch: dict) -> float:

        x_enc = batch['x_num_past'].to(self.device)

        BS = x_enc.shape[0]
        if 'x_cat_future' in batch.keys():
            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS, None)
        tmp_future = [emb_fut]
        if 'x_cat_past' in batch.keys():
            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS, None)

        if 'x_num_future' in batch.keys():
            x_future = batch['x_num_future'].to(self.device)
            tmp_future.append(x_future)
        if len(tmp_future) > 0:
            tmp_future = torch.cat(tmp_future, 2)
        else:
            tmp_future = None

        en_embed, n_vars = self.en_embedding(x_enc.permute(0, 2, 1))
        ex_embed = self.ex_embedding(x_enc, emb_past)

        enc_out = self.encoder(en_embed, ex_embed)

        enc_out = torch.reshape(enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))

        # z: [bs x nvars x d_model x patch_num]
        enc_out = enc_out.permute(0, 1, 3, 2)

        dec_out = self.head(enc_out)  # z: [bs x nvars x target_window]
        # dec_out = dec_out.permute(0, 2, 1)
        if tmp_future is not None:
            tmp_future = self.future_reshape(tmp_future.permute(0, 2, 1))
            dec_out = torch.cat([tmp_future, dec_out], 1)
        dec_out = self.final_linear(dec_out.permute(0, 2, 1))
        return dec_out.reshape(BS, self.future_steps, self.out_channels, self.mul)

        # idx_target = batch['idx_target'][0]
        # return dec_out[:, :, idx_target].reshape(BS, self.future_steps, self.out_channels, self.mul)

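For orientation, the forward pass above reads its inputs from a plain dictionary. A minimal sketch of the expected keys and shapes, with purely illustrative sizes (in dsipts the real values come from the dataset and the Base configuration; mul is presumably 1 for a point forecast or the number of quantiles):

import torch

# Illustrative sizes only; they are not taken from the package.
BS, past_steps, future_steps = 8, 64, 16
past_channels, future_channels, n_cat = 7, 2, 3

batch = {
    "x_num_past": torch.randn(BS, past_steps, past_channels),        # numeric history (required)
    "x_num_future": torch.randn(BS, future_steps, future_channels),  # known numeric future covariates (optional)
    "x_cat_past": torch.randint(0, 5, (BS, past_steps, n_cat)),      # categorical history (optional)
    "x_cat_future": torch.randint(0, 5, (BS, future_steps, n_cat)),  # categorical future covariates (optional)
}
# forward(batch) then returns a tensor of shape (BS, future_steps, out_channels, mul).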
dsipts/models/VQVAEA.py
ADDED
@@ -0,0 +1,299 @@
## Copyright (c) 2020 Andrej Karpathy https://github.com/karpathy/minGPT?tab=MIT-1-ov-file#readme
## Modified for notation alignment and batch structure, and adapted for time series
## Extended with what is inside the vva folder

from torch import nn
import torch
from torch.nn import functional as F

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from typing import List, Union
from .vva.minigpt import Block
from .vva.vqvae import VQVAE
import logging
from random import random
from ..data_structure.utils import beauty_string
from .utils import get_scope

torch.autograd.set_detect_anomaly(True)


class VQVAEA(Base):
    handle_multivariate = False
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 past_steps: int,
                 future_steps: int,
                 past_channels: int,
                 future_channels: int,
                 hidden_channels: int,
                 embs: List[int],
                 d_model: int,
                 max_voc_size: int,
                 num_layers: int,
                 dropout_rate: float,
                 commitment_cost: float,
                 decay: float,
                 n_heads: int,
                 out_channels: int,
                 epoch_vqvae: int,
                 persistence_weight: float = 0.0,
                 loss_type: str = 'l1',
                 quantiles: List[int] = [],
                 optim: Union[str, None] = None,
                 optim_config: dict = None,
                 scheduler_config: dict = None,
                 **kwargs) -> None:
        """Custom encoder-decoder

        Args:
            past_steps (int): number of past datapoints used
            future_steps (int): number of future lags to predict
            past_channels (int): number of numeric past variables, must be >0
            future_channels (int): number of future numeric variables
            hidden_channels (int): hidden channels of the VQ-VAE encoder/decoder
            embs (List[int]): list of the initial dimensions of the categorical variables
            d_model (int): dimension of the codebook vectors and of the GPT embeddings
            max_voc_size (int): number of codebook entries (vocabulary size for the GPT)
            num_layers (int): number of GPT blocks
            dropout_rate (float): dropout rate in Dropout layers
            commitment_cost (float): commitment cost of the vector-quantization loss
            decay (float): decay parameter of the vector quantizer
            n_heads (int): number of attention heads in the GPT blocks
            out_channels (int): number of output channels (must be 1)
            epoch_vqvae (int): number of epochs spent training the VQ-VAE before the GPT is trained
            persistence_weight (float): weight controlling the divergence from the persistence model. Default 0
            loss_type (str, optional): this model uses custom losses or l1 or mse. Custom losses can be linear_penalization or exponential_penalization. Default l1
            quantiles (List[int], optional): quantile loss is used if len(quantiles) > 0 (usually 0.1, 0.5, 0.9), otherwise L1Loss. Defaults to [].
            optim (str, optional): if not None it expects a pytorch optim method. Defaults to None, which is mapped to Adam.
            optim_config (dict, optional): configuration for the Adam optimizer. Defaults to None.
            scheduler_config (dict, optional): configuration for the stepLR scheduler. Defaults to None.
        """

        super().__init__(**kwargs)
        self.save_hyperparameters(logger=False)
        self.d_model = d_model
        self.max_voc_size = max_voc_size
        self.future_steps = future_steps
        self.epoch_vqvae = epoch_vqvae
        ## First the VQ-VAE
        assert out_channels == 1, beauty_string('Working only for one signal', 'section', True)
        assert past_steps % 2 == 0 and future_steps % 2 == 0, beauty_string('There are some issues with the decoder in case of odd length', 'section', True)
        self.vqvae = VQVAE(in_channels=1, hidden_channels=hidden_channels, out_channels=1, num_embeddings=max_voc_size, embedding_dim=d_model, commitment_cost=commitment_cost, decay=decay)

        ## Then the GPT

        self.block_size = past_steps // 2 + future_steps // 2 - 1
        self.sentence_length = future_steps // 2

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(max_voc_size, d_model),
            wpe=nn.Embedding(self.block_size, d_model),
            drop=nn.Dropout(dropout_rate),
            h=nn.ModuleList([Block(d_model, dropout_rate, n_heads, dropout_rate, self.block_size) for _ in range(num_layers)]),  ## careful: these could be different dropouts
            ln_f=nn.LayerNorm(d_model),
            lm_head=nn.Linear(d_model, max_voc_size, bias=False)
        ))
        # report number of parameters (note we don't count the decoder parameters in lm_head)
        n_params = sum(p.numel() for p in self.transformer.parameters())
        beauty_string("number of parameters: %.2fM" % (n_params / 1e6,), 'info', self.verbose)

        self.use_quantiles = False
        self.is_classification = True
        self.optim_config = optim_config

    def configure_optimizers(self):

        # return torch.optim.Adam(self.vqvae.parameters(), lr=self.optim_config.lr_vqvae,
        #                         weight_decay=self.optim_config.weight_decay_vqvae)

        return torch.optim.AdamW([
            {'params': self.vqvae.parameters(), 'lr': self.optim_config.lr_vqvae, 'weight_decay': self.optim_config.weight_decay_vqvae},
            {'params': self.transformer.parameters(), 'lr': self.optim_config.lr_gpt, 'weight_decay': self.optim_config.weight_decay_gpt},
        ])

    def gpt(self, tokens):

        b, t = tokens.size()
        assert t <= self.block_size, beauty_string(f"Cannot forward sequence of length {t}, block size is only {self.block_size}", 'section', True)
        pos = torch.arange(0, t, dtype=torch.long, device=self.device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(tokens)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.transformer.lm_head(x)
        return logits

    def forward(self, batch):

        ## VQ-VAE
        # current_epoch = self.current_epoch
        # if current_epoch < 1000:
        #     self.vqvae.train()
        #     loss_gpt = 100
        # else:
        #     self.vqvae.eval()
        idx_target = batch['idx_target'][0]

        # (tensor([194, 163, 174, 176, 160, 168, 175]),
        #  tensor([ -1,  -1,  -1, 160, 168, 175, 160]))

        data = batch['x_num_past'][:, :, idx_target]
        if self.current_epoch > self.epoch_vqvae:
            with torch.no_grad():
                vqloss, data_recon, perplexity, quantized_x, encodings_x = self.vqvae(data.permute(0, 2, 1))
            loss_vqvae = 0
        else:
            vq_loss, data_recon, perplexity, quantized_x, encodings_x = self.vqvae(data.permute(0, 2, 1))
            if random() < 0.001:
                beauty_string(perplexity, 'info', self.verbose)
            recon_error = F.mse_loss(data_recon.squeeze(), data.squeeze())
            loss_vqvae = recon_error + vq_loss

        if self.current_epoch > self.epoch_vqvae:
            with torch.no_grad():
                _, _, _, quantized_y, encodings_y = self.vqvae(batch['y'].permute(0, 2, 1))

            ## GPT
            tokens = torch.cat([encodings_x.argmax(dim=2), encodings_y.argmax(dim=2)[:, 0:-1]], 1)
            tokens_y = torch.cat([encodings_x.argmax(dim=2)[:, 0:-1], encodings_y.argmax(dim=2)], 1)
            tokens_y[:, 0:encodings_x.shape[1] - 1] = -1
            logits = self.gpt(tokens)
            loss_gpt = F.cross_entropy(logits.view(-1, logits.size(-1)), tokens_y.view(-1), ignore_index=-1)

            ## now reconstruct y because that is what we want as output
            with torch.no_grad():
                encoding_indices = torch.argmax(logits.reshape(-1, self.max_voc_size), dim=1).unsqueeze(1)
                encodings = torch.zeros(encoding_indices.shape[0], self.vqvae._vq_vae._num_embeddings, device=self.device)
                encodings.scatter_(1, encoding_indices, 1)
                quantized = torch.matmul(encodings, self.vqvae._vq_vae._embedding.weight).view(data.shape[0], -1, self.d_model)  ## B x L x hidden
                quantized = quantized.permute(0, 2, 1).contiguous()
                y_hat = self.vqvae._decoder(quantized, False).squeeze()[:, -self.future_steps:]

            l1_loss = nn.L1Loss()(y_hat, batch['y'].squeeze())

            return y_hat, loss_vqvae + loss_gpt + l1_loss
        else:
            return None, loss_vqvae

    def training_step(self, batch, batch_idx):
        """
        PyTorch Lightning stuff

        :meta private:
        """
        _, loss = self(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        PyTorch Lightning stuff

        :meta private:
        """
        _, loss = self(batch)
        return loss

    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None, num_samples=100):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        assert do_sample is False, logging.info('NOT IMPLEMENTED YET')
        if do_sample:

            idx = idx.repeat(num_samples, 1, 1)
            for _ in range(max_new_tokens):
                tmp = []
                for i in range(num_samples):
                    idx_cond = idx[i, :, :] if idx.size(2) <= self.block_size else idx[i, :, -self.block_size:]
                    logits = self.gpt(idx_cond)
                    logits = logits[:, -1, :] / temperature
                    if top_k is not None:
                        v, _ = torch.topk(logits, top_k)
                        logits[logits < v[:, [-1]]] = -float('Inf')
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1, replacement=True)
                    tmp.append(idx_next)
                tmp = torch.cat(tmp, dim=1).T.unsqueeze(2)
                idx = torch.cat((idx, tmp), dim=2)
            return idx
        else:
            for _ in range(max_new_tokens):

                # if the sequence context is growing too long we must crop it at block_size
                idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
                # forward the model to get the logits for the index in the sequence
                logits = self.gpt(idx_cond)
                # pluck the logits at the final step and scale by desired temperature
                logits = logits[:, -1, :] / temperature
                # optionally crop the logits to only the top k options
                if top_k is not None:
                    v, _ = torch.topk(logits, top_k)
                    logits[logits < v[:, [-1]]] = -float('Inf')
                # apply softmax to convert logits to (normalized) probabilities
                probs = F.softmax(logits, dim=-1)
                # either sample from the distribution or take the most likely element
                _, idx_next = torch.topk(probs, k=1, dim=-1)
                # append sampled index to the running sequence and continue
                idx = torch.cat((idx, idx_next), dim=1)

            return idx

    def inference(self, batch: dict) -> torch.tensor:

        idx_target = batch['idx_target'][0]
        data = batch['x_num_past'][:, :, idx_target].to(self.device)
        vq_loss, data_recon, perplexity, quantized_x, encodings_x = self.vqvae(data.permute(0, 2, 1))
        x = encodings_x.argmax(dim=2)
        inp = x[:, :self.sentence_length]
        # let the model sample the rest of the sequence
        cat = self.generate(inp, self.sentence_length, do_sample=False)  # cannot handle do_sample here :-)
        encoding_indices = cat.flatten().unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.vqvae._vq_vae._num_embeddings, device=self.device)
        encodings.scatter_(1, encoding_indices, 1)
        quantized = torch.matmul(encodings, self.vqvae._vq_vae._embedding.weight).view(x.shape[0], -1, self.d_model)  ## B x L x hidden
        quantized = quantized.permute(0, 2, 1).contiguous()
        y_hat = self.vqvae._decoder(quantized, False).squeeze()[:, -self.future_steps:]

        ## B x L x C x 3
        return y_hat.unsqueeze(2).unsqueeze(3)
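
Both forward and inference decode the GPT predictions by turning token indices back into codebook vectors through a one-hot lookup. A self-contained sketch of just that step, with a random stand-in codebook and illustrative sizes (in the package the real tensors live in self.vqvae._vq_vae):

import torch

# Stand-in sizes and codebook, for illustration only.
B, L, num_embeddings, d_model = 4, 10, 128, 32
codebook = torch.randn(num_embeddings, d_model)

# Token indices as produced by argmax over the GPT logits, flattened to (B*L, 1).
encoding_indices = torch.randint(0, num_embeddings, (B * L, 1))

# One-hot encode the indices and look up the corresponding codebook vectors.
encodings = torch.zeros(B * L, num_embeddings)
encodings.scatter_(1, encoding_indices, 1)
quantized = torch.matmul(encodings, codebook).view(B, L, d_model)  # B x L x d_model
quantized = quantized.permute(0, 2, 1).contiguous()                # B x d_model x L, as the decoder expects
print(quantized.shape)  # torch.Size([4, 32, 10])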