dsipts-1.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/samformer/utils.py (new file, 154 added lines)
@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import Optimizer


def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    """
    A copy-paste from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / np.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value


class RevIN(nn.Module):
    """
    Reversible Instance Normalization (RevIN) https://openreview.net/pdf?id=cGDAkQo1C0p
    https://github.com/ts-kim/RevIN
    """
    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        x = x + self.mean
        return x


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # Perturb weights in the gradient direction

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # Restore original weights

        self.base_optimizer.step()  # Apply the sharpness-aware update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"

        with torch.enable_grad():
            closure()  # First forward-backward pass

        self.first_step(zero_grad=True)

        with torch.enable_grad():
            closure()  # Second forward-backward pass

        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device
        grads = [
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]
        return torch.norm(torch.stack(grads), p=2) if grads else torch.tensor(0.0, device=shared_device)

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        if hasattr(self, "base_optimizer"):  # Ensure base optimizer exists
            self.base_optimizer.load_state_dict(state_dict["base_optimizer"])
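For orientation, here is a minimal usage sketch for the two utilities in the hunk above (RevIN and SAM). It is not part of the package; the import path is inferred from the file list (dsipts/models/samformer/utils.py), and the toy model and data are illustrative only.

# Usage sketch (not part of the package); assumes the classes above are
# importable from dsipts.models.samformer.utils.
import torch
import torch.nn as nn
from dsipts.models.samformer.utils import RevIN, SAM

x = torch.randn(8, 96, 7)                 # toy batch: [batch, time, channels]
revin = RevIN(num_features=7)
x_norm = revin(x, mode='norm')            # instance-normalize, statistics are stored
x_back = revin(x_norm, mode='denorm')     # invert the normalization

model = nn.Linear(7, 7)                   # toy model for the SAM loop
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.01)

def closure():
    optimizer.zero_grad()
    loss = ((model(x_norm) - x_norm) ** 2).mean()
    loss.backward()
    return loss

optimizer.step(closure)                   # runs two forward/backward passes internally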
dsipts/models/tft/sub_nn.py (new file, 234 added lines)
@@ -0,0 +1,234 @@
import torch
import torch.nn as nn
from typing import Union

class embedding_cat_variables(nn.Module):
    # at the moment cat_past and cat_fut together
    def __init__(self, seq_len: int, lag: int, d_model: int, emb_dims: list, device):
        """Class for embedding categorical variables, adding 3 positional variables during forward

        Args:
            seq_len (int): length of the sequence (sum of past and future steps)
            lag (int): number of future steps to be predicted
            d_model (int): dimension of all variables after they are embedded
            emb_dims (list): size of the dictionary for embedding. One dimension for each categorical variable
            device: -
        """
        super().__init__()
        self.seq_len = seq_len
        self.lag = lag
        self.device = device
        self.cat_embeds = emb_dims + [seq_len, lag+1, 2]
        self.cat_n_embd = nn.ModuleList([
            nn.Embedding(emb_dim, d_model) for emb_dim in self.cat_embeds
        ])

    def forward(self, x: Union[torch.Tensor, int], device: torch.device) -> torch.Tensor:
        """All components of x are concatenated with 3 new variables for data augmentation, in the order:
        - pos_seq: assigns to each step its time position
        - pos_fut: assigns to each step its future position; 0 if it is a past step
        - is_fut: marks each step as future (1) or past (0)

        Args:
            x (torch.Tensor): [bs, seq_len, num_vars]

        Returns:
            torch.Tensor: [bs, seq_len, num_vars+3, n_embd]
        """
        if isinstance(x, int):
            no_emb = True
            B = x
        else:
            no_emb = False
            B, _, _ = x.shape

        pos_seq = self.get_pos_seq(bs=B).to(device)
        pos_fut = self.get_pos_fut(bs=B).to(device)
        is_fut = self.get_is_fut(bs=B).to(device)

        if no_emb:
            cat_vars = torch.cat((pos_seq, pos_fut, is_fut), dim=2)
        else:
            cat_vars = torch.cat((x, pos_seq, pos_fut, is_fut), dim=2)

        cat_n_embd = self.get_cat_n_embd(cat_vars)
        return cat_n_embd

    def get_pos_seq(self, bs):
        pos_seq = torch.arange(0, self.seq_len)
        pos_seq = pos_seq.repeat(bs, 1).unsqueeze(2).to(self.device)
        return pos_seq

    def get_pos_fut(self, bs):
        pos_fut = torch.cat((torch.zeros((self.seq_len-self.lag), dtype=torch.long), torch.arange(1, self.lag+1)))
        pos_fut = pos_fut.repeat(bs, 1).unsqueeze(2).to(self.device)
        return pos_fut

    def get_is_fut(self, bs):
        is_fut = torch.cat((torch.zeros((self.seq_len-self.lag), dtype=torch.long), torch.ones((self.lag), dtype=torch.long)))
        is_fut = is_fut.repeat(bs, 1).unsqueeze(2).to(self.device)
        return is_fut

    def get_cat_n_embd(self, cat_vars):
        cat_n_embd = torch.Tensor().to(cat_vars.device)
        for index, layer in enumerate(self.cat_n_embd):
            emb = layer(cat_vars[:, :, index])
            cat_n_embd = torch.cat((cat_n_embd, emb.unsqueeze(2)), dim=2)
        return cat_n_embd

class LSTM_Model(nn.Module):
    def __init__(self, num_var: int, d_model: int, pred_step: int, num_layers: int, dropout: float):
        """LSTM mapping [..., d_model] to [B, pred_step, num_var]

        Args:
            num_var (int): number of variables encoded in the input tensor
            d_model (int): encoding dimension of the tensor
            pred_step (int): number of steps to be predicted by the LSTM
            num_layers (int): number of LSTM layers
            dropout (float): dropout rate
        """

        super().__init__()
        self.num_var = num_var
        self.d_model = d_model
        self.num_layers = num_layers
        self.pred_step = pred_step

        self.lstm = nn.LSTM(d_model, d_model, num_layers=num_layers, batch_first=True, dropout=dropout)
        self.linear = nn.Linear(d_model, pred_step*num_var)

    def forward(self, x):
        """LSTM pass over the x tensor, reshaped according to pred_step and num_var to be predicted

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: tensor resized to [B, pred_step, num_var]
        """

        h0 = torch.zeros(self.num_layers, x.size(0), self.d_model).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.d_model).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.linear(out[:, -1, :])  # Take the last output of the sequence
        out = out.view(-1, self.pred_step, self.num_var)
        return out

class GLU(nn.Module):
    def __init__(self, d_model: int):
        """Gated Linear Unit

        Auxiliary subnet for sigmoid element-wise multiplication

        Args:
            d_model (int): dimension of operations
        """

        super().__init__()
        self.linear1 = nn.Linear(d_model, d_model, bias=False)
        self.linear2 = nn.Linear(d_model, d_model, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.sigmoid(self.linear1(x))
        x2 = self.linear2(x)
        out = x1 * x2  # element-wise multiplication
        return out

class GRN(nn.Module):
    def __init__(self, d_model: int, dropout_rate: float):
        """Gated Residual Network

        Auxiliary subnet for gating residual connections

        Args:
            d_model (int): dimension of operations
            dropout_rate (float): dropout rate
        """

        super().__init__()
        self.linear1 = nn.Linear(d_model, d_model)
        self.elu = nn.ELU()
        self.linear2 = nn.Linear(d_model, d_model)
        self.res_conn = ResidualConnection(d_model, dropout_rate)

    def forward(self, x: torch.Tensor, using_norm: bool = True) -> torch.Tensor:
        eta1 = self.elu(self.linear1(x))
        eta2 = self.linear2(eta1)
        out = self.res_conn(eta2, x, using_norm)
        return out

class ResidualConnection(nn.Module):
    def __init__(self, d_model, dropout_rate):
        """Residual connection of res_conn with GLU(x)

        Auxiliary subnet for residual connections

        Args:
            d_model (int): dimension of operations
            dropout_rate (float): dropout rate
        """

        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.glu = GLU(d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, res_conn: torch.Tensor, using_norm: bool = True) -> torch.Tensor:
        """Residual connection applying GLU(dropout(x)) to 'x', adding 'res_conn', and optionally normalizing

        Args:
            x (torch.Tensor): tensor passed through GLU(dropout(x))
            res_conn (torch.Tensor): tensor summed to x before normalization
            using_norm (bool, optional): whether to apply LayerNorm. Defaults to True.

        Returns:
            torch.Tensor: gated residual output
        """
        x = self.glu(self.dropout(x))
        out = res_conn + x
        if using_norm:
            out = self.norm(out)
        return out

class InterpretableMultiHead(nn.Module):
    def __init__(self, d_model, d_head, n_head):
        """Interpretable MultiHead Attention

        Similar to canonical multi-head attention with a query-key-value structure.
        Its particularities are:
        - a single "Value" linear layer shared by all heads
        - the outputs of all heads are summed together and then rescaled by the number of heads
        The final output tensor is re-embedded in the initial dimension.

        Args:
            d_model (int): starting and ending dimension of the net
            d_head (int): hidden dimension of all heads
            n_head (int): number of heads
        """

        super().__init__()
        self.d_head = d_head
        self.n_head = n_head
        self.Q_layers = nn.ModuleList([nn.Linear(d_model, d_head) for _ in range(n_head)])
        self.K_layers = nn.ModuleList([nn.Linear(d_model, d_head) for _ in range(n_head)])
        self.Softmax_layers = nn.ModuleList([nn.Softmax(dim=-1) for _ in range(n_head)])
        self.V_layer = nn.Linear(d_model, d_head)
        self.out_layer = nn.Linear(d_head, d_model)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        out = torch.Tensor()
        for (q_layer, k_layer, softmax) in zip(self.Q_layers, self.K_layers, self.Softmax_layers):
            Q = q_layer(query)
            K = k_layer(key)
            wei = Q @ K.transpose(-2, -1) * (self.d_head**-0.5)
            wei = softmax(wei)
            V = self.V_layer(value)
            out_h = wei @ V
            if out.shape[0] > 0:
                out = out + out_h  # sum the result of this head's attention
            else:
                out = out_h  # out has not been initialized yet
        out = out / self.n_head
        out = self.out_layer(out)  # come back to d_model dimension
        return out
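As a quick orientation for the TFT sub-networks in the hunk above, the sketch below wires a few of them together. It is not part of the package; the import path is inferred from the file list (dsipts/models/tft/sub_nn.py), and all shapes and sizes are illustrative.

# Usage sketch (not part of the package); assumes the classes above are
# importable from dsipts.models.tft.sub_nn.
import torch
from dsipts.models.tft.sub_nn import embedding_cat_variables, GRN, InterpretableMultiHead

device = torch.device('cpu')
d_model, seq_len, lag = 16, 20, 5

# Two categorical variables with dictionary sizes 12 and 7; the module appends
# pos_seq, pos_fut and is_fut as three extra categorical features.
emb = embedding_cat_variables(seq_len=seq_len, lag=lag, d_model=d_model,
                              emb_dims=[12, 7], device=device)
cat = torch.randint(0, 7, (4, seq_len, 2))   # [bs, seq_len, num_cat_vars]
cat_emb = emb(cat, device)                   # [4, seq_len, 2 + 3, d_model]

grn = GRN(d_model, dropout_rate=0.1)
attn = InterpretableMultiHead(d_model, d_head=8, n_head=4)

x = torch.randn(4, seq_len, d_model)
y = grn(x)                                   # gated residual transform, same shape
z = attn(y, y, y)                            # self-attention, back to [4, seq_len, d_model]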
dsipts/models/timexer/Layers.py (new file, 127 added lines)
@@ -0,0 +1,127 @@
import torch.nn as nn
import torch
import math
import torch.nn.functional as F


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:, :x.size(1)]

class EnEmbedding(nn.Module):
    def __init__(self, n_vars, d_model, patch_len, dropout):
        super(EnEmbedding, self).__init__()
        # Patching
        self.patch_len = patch_len

        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
        self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model))
        self.position_embedding = PositionalEmbedding(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # do patching
        n_vars = x.shape[1]
        glb = self.glb_token.repeat((x.shape[0], 1, 1, 1))

        x = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len)
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
        # Input encoding
        x = self.value_embedding(x) + self.position_embedding(x)
        x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1]))
        x = torch.cat([x, glb], dim=2)
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
        return self.dropout(x), n_vars

class FlattenHead(nn.Module):
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        x = self.flatten(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x


class EncoderLayer(nn.Module):
    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        B, L, D = cross.shape
        x = x + self.dropout(self.self_attention(
            x, x, x,
            attn_mask=x_mask,
            tau=tau, delta=None
        )[0])
        x = self.norm1(x)

        x_glb_ori = x[:, -1, :].unsqueeze(1)
        x_glb = torch.reshape(x_glb_ori, (B, -1, D))
        x_glb_attn = self.dropout(self.cross_attention(
            x_glb, cross, cross,
            attn_mask=cross_mask,
            tau=tau, delta=delta
        )[0])
        x_glb_attn = torch.reshape(x_glb_attn,
                                   (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2])).unsqueeze(1)
        x_glb = x_glb_ori + x_glb_attn
        x_glb = self.norm2(x_glb)

        y = x = torch.cat([x[:, :-1, :], x_glb], dim=1)

        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm3(x + y)

class Encoder(nn.Module):
    def __init__(self, layers, norm_layer=None, projection=None):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        for layer in self.layers:
            x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x
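Finally, a small shape-check sketch for the TimeXER embedding and head in the hunk above. It is not part of the package; the import path is inferred from the file list (dsipts/models/timexer/Layers.py), and the configuration values are illustrative.

# Usage sketch (not part of the package); assumes the classes above are
# importable from dsipts.models.timexer.Layers.
import torch
from dsipts.models.timexer.Layers import EnEmbedding, FlattenHead

bs, n_vars, seq_len = 4, 7, 96
d_model, patch_len, pred_len = 32, 16, 24

embed = EnEmbedding(n_vars=n_vars, d_model=d_model, patch_len=patch_len, dropout=0.1)
x = torch.randn(bs, n_vars, seq_len)                  # [bs, n_vars, seq_len]
tokens, n = embed(x)                                  # [bs * n_vars, patch_num + 1, d_model]

patch_num = seq_len // patch_len                      # 6 patches, plus one global token
head = FlattenHead(n_vars=n_vars, nf=d_model * (patch_num + 1), target_window=pred_len)

enc_out = tokens.reshape(bs, n_vars, patch_num + 1, d_model).permute(0, 1, 3, 2)
y = head(enc_out)                                     # [bs, n_vars, pred_len]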