dsipts-1.1.5-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/informer/attn.py
@@ -0,0 +1,185 @@
+ import torch
+ import torch.nn as nn
+
+
+ import numpy as np
+
+ from math import sqrt
+
+
+ class TriangularCausalMask():
+     def __init__(self, B, L,device):
+         mask_shape = [B, 1, L, L]
+         with torch.no_grad():
+             self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
+
+     @property
+     def mask(self):
+         return self._mask
+
+ class ProbMask():
+     def __init__(self, B, H, L, index, scores,device):
+         _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).triu(1).to(device)
+         _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
+         indicator = _mask_ex[torch.arange(B)[:, None, None],torch.arange(H)[None, :, None],index, :]
+         self._mask = indicator.view(scores.shape)
+
+     @property
+     def mask(self):
+         return self._mask
+
+
+ class FullAttention(nn.Module):
+     def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
+         super(FullAttention, self).__init__()
+         self.scale = scale
+         self.mask_flag = mask_flag
+         self.output_attention = output_attention
+         self.dropout = nn.Dropout(attention_dropout)
+
+     def forward(self, queries, keys, values, attn_mask):
+         B, L, H, E = queries.shape
+         _, S, _, D = values.shape
+         scale = self.scale or 1./sqrt(E)
+
+         scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+         if self.mask_flag:
+             if attn_mask is None:
+                 attn_mask = TriangularCausalMask(B, L, device=queries.device)
+
+             scores.masked_fill_(attn_mask.mask, -np.inf)
+
+         A = self.dropout(torch.softmax(scale * scores, dim=-1))
+         V = torch.einsum("bhls,bshd->blhd", A, values)
+
+         if self.output_attention:
+             return (V.contiguous(), A)
+         else:
+             return (V.contiguous(), None)
+
+ class ProbAttention(nn.Module):
+     def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
+         super(ProbAttention, self).__init__()
+         self.factor = factor
+         self.scale = scale
+         self.mask_flag = mask_flag
+         self.output_attention = output_attention
+         self.dropout = nn.Dropout(attention_dropout)
+
+     def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
+         # Q [B, H, L, D]
+         B, H, L_K, E = K.shape
+         _, _, L_Q, _ = Q.shape
+
+         # calculate the sampled Q_K
+         K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
+         index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
+         K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
+         Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)
+
+         # find the Top_k query with sparisty measurement
+         M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
+         M_top = M.topk(n_top, sorted=False)[1]
+
+         # use the reduced Q to calculate Q_K
+         Q_reduce = Q[torch.arange(B)[:, None, None],
+                      torch.arange(H)[None, :, None],
+                      M_top, :] # factor*ln(L_q)
+         Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
+
+         return Q_K, M_top
+
+     def _get_initial_context(self, V, L_Q):
+         B, H, L_V, D = V.shape
+         if not self.mask_flag:
+             # V_sum = V.sum(dim=-2)
+             V_sum = V.mean(dim=-2)
+             contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
+         else: # use mask
+             assert(L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only
+             contex = V.cumsum(dim=-2)
+         return contex
+
+     def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
+         B, H, L_V, D = V.shape
+
+         if self.mask_flag:
+             attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
+             scores.masked_fill_(attn_mask.mask, -np.inf)
+
+         attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
+
+         context_in[torch.arange(B)[:, None, None],
+                    torch.arange(H)[None, :, None],
+                    index, :] = torch.matmul(attn, V).type_as(context_in)
+         if self.output_attention:
+             attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device)
+             attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
+             return (context_in, attns)
+         else:
+             return (context_in, None)
+
+     def forward(self, queries, keys, values, attn_mask):
+         B, L_Q, H, D = queries.shape
+         _, L_K, _, _ = keys.shape
+
+         queries = queries.transpose(2,1)
+         keys = keys.transpose(2,1)
+         values = values.transpose(2,1)
+
+         U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
+         u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)
+
+         U_part = U_part if U_part<L_K else L_K
+         u = u if u<L_Q else L_Q
+
+         scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
+
+         # add scale factor
+         scale = self.scale or 1./sqrt(D)
+         if scale is not None:
+             scores_top = scores_top * scale
+         # get the context
+         context = self._get_initial_context(values, L_Q)
+         # update the context with selected top_k queries
+         context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
+
+         return context.transpose(2,1).contiguous(), attn
+
+
+ class AttentionLayer(nn.Module):
+     def __init__(self, attention, d_model, n_heads,
+                  d_keys=None, d_values=None, mix=False):
+         super(AttentionLayer, self).__init__()
+
+         d_keys = d_keys or (d_model//n_heads)
+         d_values = d_values or (d_model//n_heads)
+
+         self.inner_attention = attention
+         self.query_projection = nn.Linear(d_model, d_keys * n_heads)
+         self.key_projection = nn.Linear(d_model, d_keys * n_heads)
+         self.value_projection = nn.Linear(d_model, d_values * n_heads)
+         self.out_projection = nn.Linear(d_values * n_heads, d_model)
+         self.n_heads = n_heads
+         self.mix = mix
+
+     def forward(self, queries, keys, values, attn_mask):
+         B, L, _ = queries.shape
+         _, S, _ = keys.shape
+         H = self.n_heads
+
+         queries = self.query_projection(queries).view(B, L, H, -1)
+         keys = self.key_projection(keys).view(B, S, H, -1)
+         values = self.value_projection(values).view(B, S, H, -1)
+
+         out, attn = self.inner_attention(
+             queries,
+             keys,
+             values,
+             attn_mask
+         )
+         if self.mix:
+             out = out.transpose(2,1).contiguous()
+         out = out.view(B, L, -1)
+
+         return self.out_projection(out), attn
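The hunk above adds dsipts/models/informer/attn.py with the Informer attention primitives (TriangularCausalMask, ProbMask, FullAttention, ProbAttention, AttentionLayer). For orientation, the sketch below shows how the layers compose; it is not part of the package, and the batch size, sequence length, and model dimensions are illustrative only.

```python
# Usage sketch (not from dsipts): compose FullAttention / ProbAttention with
# AttentionLayer. All sizes below are illustrative.
import torch
from dsipts.models.informer.attn import FullAttention, ProbAttention, AttentionLayer

B, L, d_model, n_heads = 2, 96, 64, 4
x = torch.randn(B, L, d_model)

# Dense attention with a causal (triangular) mask, e.g. for masked self-attention.
full = AttentionLayer(FullAttention(mask_flag=True, attention_dropout=0.1), d_model, n_heads)
out, _ = full(x, x, x, attn_mask=None)      # -> (B, L, d_model)

# ProbSparse attention: only the top c*ln(L) queries attend densely.
prob = AttentionLayer(ProbAttention(mask_flag=False, factor=5, output_attention=True),
                      d_model, n_heads)
out, attn = prob(x, x, x, attn_mask=None)   # out: (B, L, d_model), attn: (B, H, L, L)
```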
dsipts/models/informer/decoder.py
@@ -0,0 +1,50 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class DecoderLayer(nn.Module):
+     def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
+                  dropout=0.1, activation="relu"):
+         super(DecoderLayer, self).__init__()
+         d_ff = d_ff or 4*d_model
+         self.self_attention = self_attention
+         self.cross_attention = cross_attention
+         self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+         self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+         self.activation = F.relu if activation == "relu" else F.gelu
+
+     def forward(self, x, cross, x_mask=None, cross_mask=None):
+         x = x + self.dropout(self.self_attention(
+             x, x, x,
+             attn_mask=x_mask
+         )[0])
+         x = self.norm1(x)
+
+         x = x + self.dropout(self.cross_attention(
+             x, cross, cross,
+             attn_mask=cross_mask
+         )[0])
+
+         y = x = self.norm2(x)
+         y = self.dropout(self.activation(self.conv1(y.transpose(-1,1))))
+         y = self.dropout(self.conv2(y).transpose(-1,1))
+
+         return self.norm3(x+y)
+
+ class Decoder(nn.Module):
+     def __init__(self, layers, norm_layer=None):
+         super(Decoder, self).__init__()
+         self.layers = nn.ModuleList(layers)
+         self.norm = norm_layer
+
+     def forward(self, x, cross, x_mask=None, cross_mask=None):
+         for layer in self.layers:
+             x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
+
+         if self.norm is not None:
+             x = self.norm(x)
+
+         return x
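The decoder hunk composes the attention layers from attn.py into DecoderLayer/Decoder blocks: masked self-attention, cross-attention over the encoder output, then a position-wise feed-forward built from 1x1 convolutions. A minimal wiring sketch, not part of the package, with illustrative dimensions:

```python
# Wiring sketch (not from dsipts): stack DecoderLayer blocks into a Decoder.
import torch
import torch.nn as nn
from dsipts.models.informer.attn import FullAttention, ProbAttention, AttentionLayer
from dsipts.models.informer.decoder import Decoder, DecoderLayer

d_model, n_heads, d_ff = 64, 4, 256
layers = [
    DecoderLayer(
        AttentionLayer(ProbAttention(mask_flag=True), d_model, n_heads, mix=True),    # masked self-attention
        AttentionLayer(FullAttention(mask_flag=False), d_model, n_heads, mix=False),  # cross-attention
        d_model, d_ff, dropout=0.1, activation="gelu")
    for _ in range(2)
]
decoder = Decoder(layers, norm_layer=nn.LayerNorm(d_model))

dec_in = torch.randn(2, 48, d_model)   # decoder tokens
enc_out = torch.randn(2, 96, d_model)  # encoder output
y = decoder(dec_in, enc_out)           # -> (2, 48, d_model)
```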
dsipts/models/informer/embed.py
@@ -0,0 +1,125 @@
+ import torch
+ import torch.nn as nn
+
+ import math
+
+ class PositionalEmbedding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super(PositionalEmbedding, self).__init__()
+         # Compute the positional encodings once in log space.
+         pe = torch.zeros(max_len, d_model).float()
+         pe.require_grad = False
+
+         position = torch.arange(0, max_len).float().unsqueeze(1)
+         div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
+
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         return self.pe[:, :x.size(1)]
+
+ class TokenEmbedding(nn.Module):
+     def __init__(self, c_in, d_model):
+         super(TokenEmbedding, self).__init__()
+         padding = 1 if torch.__version__>='1.5.0' else 2
+         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
+                                    kernel_size=3, padding=padding, padding_mode='circular')
+         for m in self.modules():
+             if isinstance(m, nn.Conv1d):
+                 nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='leaky_relu')
+
+     def forward(self, x):
+         x = self.tokenConv(x.permute(0, 2, 1)).transpose(1,2)
+         return x
+
+ class FixedEmbedding(nn.Module):
+     def __init__(self, c_in, d_model):
+         super(FixedEmbedding, self).__init__()
+
+         w = torch.zeros(c_in, d_model).float()
+         w.require_grad = False
+
+         position = torch.arange(0, c_in).float().unsqueeze(1)
+         div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
+
+         w[:, 0::2] = torch.sin(position * div_term)
+         w[:, 1::2] = torch.cos(position * div_term)
+
+         self.emb = nn.Embedding(c_in, d_model)
+         self.emb.weight = nn.Parameter(w, requires_grad=False)
+
+     def forward(self, x):
+         return self.emb(x).detach()
+
+ class TemporalEmbedding(nn.Module):
+     def __init__(self, d_model, embed_type='fixed', freq='h'):
+         super(TemporalEmbedding, self).__init__()
+
+         minute_size = 4
+         hour_size = 24
+         weekday_size = 7
+         day_size = 32
+         month_size = 13
+
+         Embed = FixedEmbedding if embed_type=='fixed' else nn.Embedding
+         if freq=='t':
+             self.minute_embed = Embed(minute_size, d_model)
+         self.hour_embed = Embed(hour_size, d_model)
+         self.weekday_embed = Embed(weekday_size, d_model)
+         self.day_embed = Embed(day_size, d_model)
+         self.month_embed = Embed(month_size, d_model)
+
+     def forward(self, x):
+         x = x.long()
+
+         minute_x = self.minute_embed(x[:,:,4]) if hasattr(self, 'minute_embed') else 0.
+         hour_x = self.hour_embed(x[:,:,3])
+         weekday_x = self.weekday_embed(x[:,:,2])
+         day_x = self.day_embed(x[:,:,1])
+         month_x = self.month_embed(x[:,:,0])
+
+         return hour_x + weekday_x + day_x + month_x + minute_x
+
+ class TimeFeatureEmbedding(nn.Module):
+     def __init__(self, d_model, embed_type='timeF', freq='h'):
+         super(TimeFeatureEmbedding, self).__init__()
+
+         freq_map = {'h':4, 't':5, 's':6, 'm':1, 'a':1, 'w':2, 'd':3, 'b':3}
+         d_inp = freq_map[freq]
+         self.embed = nn.Linear(d_inp, d_model)
+
+     def forward(self, x):
+         return self.embed(x)
+
+
+
+ class DataEmbedding(nn.Module):
+     def __init__(self, c_in, d_model, embs, dropout=0.1):
+         super(DataEmbedding, self).__init__()
+
+         self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
+         self.position_embedding = PositionalEmbedding(d_model=d_model)
+         #self.temporal_embedding = TemporalEmbedding(d_model=d_model, freq=freq)
+         self.emb_list = nn.ModuleList()
+         if embs is not None:
+             for k in embs:
+                 self.emb_list.append(nn.Embedding(k+1,d_model))
+
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, x, x_mark):
+         tot = None
+         for i in range(len(self.emb_list)):
+             if tot is None:
+                 tot = self.emb_list[i](x_mark[:,:,i])
+             else:
+                 tot += self.emb_list[i](x_mark[:,:,i])
+         if tot is not None:
+             x = self.value_embedding(x) + tot + self.position_embedding(x)
+         else:
+             x = self.value_embedding(x) + self.position_embedding(x)
+         return self.dropout(x)
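This embed.py diverges from the upstream Informer version: DataEmbedding takes a list `embs` of categorical cardinalities and builds one nn.Embedding per covariate, while the calendar-based TemporalEmbedding is defined but left unused (its call is commented out). Note also that `pe.require_grad = False` sets a stray attribute rather than `requires_grad`; since `pe` is registered as a buffer it is not trained either way. A minimal sketch of the dsipts-style call, not part of the package, with assumed cardinalities:

```python
# Usage sketch (not from dsipts): embed continuous inputs plus integer categorical
# covariates. The cardinalities in `embs` are assumptions for illustration.
import torch
from dsipts.models.informer.embed import DataEmbedding

B, L, c_in, d_model = 2, 96, 7, 64
embs = [23, 6, 11]   # hypothetical max values of three categorical covariates
emb = DataEmbedding(c_in=c_in, d_model=d_model, embs=embs, dropout=0.1)

x = torch.randn(B, L, c_in)                                                    # continuous channels
x_mark = torch.stack([torch.randint(0, k + 1, (B, L)) for k in embs], dim=-1)  # integer covariates
out = emb(x, x_mark)                                                           # -> (B, L, d_model)
```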
dsipts/models/informer/encoder.py
@@ -0,0 +1,100 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class ConvLayer(nn.Module):
+     def __init__(self, c_in):
+         super(ConvLayer, self).__init__()
+         padding = 1 if torch.__version__>='1.5.0' else 2
+         self.downConv = nn.Conv1d(in_channels=c_in,
+                                   out_channels=c_in,
+                                   kernel_size=3,
+                                   padding=padding,
+                                   padding_mode='circular')
+         self.norm = nn.BatchNorm1d(c_in)
+         self.activation = nn.ELU()
+         self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+
+     def forward(self, x):
+         x = self.downConv(x.permute(0, 2, 1))
+         x = self.norm(x)
+         x = self.activation(x)
+         x = self.maxPool(x)
+         x = x.transpose(1,2)
+         return x
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
+         super(EncoderLayer, self).__init__()
+         d_ff = d_ff or 4*d_model
+         self.attention = attention
+         self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+         self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+         self.activation = F.relu if activation == "relu" else F.gelu
+
+     def forward(self, x, attn_mask=None):
+         # x [B, L, D]
+         # x = x + self.dropout(self.attention(
+         #     x, x, x,
+         #     attn_mask = attn_mask
+         # ))
+         new_x, attn = self.attention(
+             x, x, x,
+             attn_mask = attn_mask
+         )
+         x = x + self.dropout(new_x)
+
+         y = x = self.norm1(x)
+         y = self.dropout(self.activation(self.conv1(y.transpose(-1,1))))
+         y = self.dropout(self.conv2(y).transpose(-1,1))
+
+         return self.norm2(x+y), attn
+
+ class Encoder(nn.Module):
+     def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
+         super(Encoder, self).__init__()
+         self.attn_layers = nn.ModuleList(attn_layers)
+         self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
+         self.norm = norm_layer
+
+     def forward(self, x, attn_mask=None):
+         # x [B, L, D]
+         attns = []
+         if self.conv_layers is not None:
+             for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
+                 x, attn = attn_layer(x, attn_mask=attn_mask)
+                 x = conv_layer(x)
+                 attns.append(attn)
+             x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
+             attns.append(attn)
+         else:
+             for attn_layer in self.attn_layers:
+                 x, attn = attn_layer(x, attn_mask=attn_mask)
+                 attns.append(attn)
+
+         if self.norm is not None:
+             x = self.norm(x)
+
+         return x, attns
+
+ class EncoderStack(nn.Module):
+     def __init__(self, encoders, inp_lens):
+         super(EncoderStack, self).__init__()
+         self.encoders = nn.ModuleList(encoders)
+         self.inp_lens = inp_lens
+
+     def forward(self, x, attn_mask=None):
+         # x [B, L, D]
+         x_stack = []
+         attns = []
+         for i_len, encoder in zip(self.inp_lens, self.encoders):
+             inp_len = x.shape[1]//(2**i_len)
+             x_s, attn = encoder(x[:, -inp_len:, :])
+             x_stack.append(x_s)
+             attns.append(attn)
+         x_stack = torch.cat(x_stack, -2)
+
+         return x_stack, attns
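The encoder hunk wires attention blocks with optional ConvLayer distilling, where each ConvLayer roughly halves the sequence length between attention blocks. A minimal sketch, not part of the package, with illustrative dimensions:

```python
# Wiring sketch (not from dsipts): encoder with ProbSparse attention and distilling.
import torch
import torch.nn as nn
from dsipts.models.informer.attn import ProbAttention, AttentionLayer
from dsipts.models.informer.encoder import Encoder, EncoderLayer, ConvLayer

d_model, n_heads, e_layers = 64, 4, 3
encoder = Encoder(
    [EncoderLayer(AttentionLayer(ProbAttention(mask_flag=False), d_model, n_heads),
                  d_model, d_ff=256, dropout=0.1, activation="gelu")
     for _ in range(e_layers)],
    [ConvLayer(d_model) for _ in range(e_layers - 1)],   # distilling convs between blocks
    norm_layer=nn.LayerNorm(d_model),
)

x = torch.randn(2, 96, d_model)
enc_out, attns = encoder(x)   # -> (2, 24, 64): 96 -> 48 -> 24 after two ConvLayers
```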
dsipts/models/itransformer/Embed.py
@@ -0,0 +1,142 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+
+ class PositionalEmbedding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super(PositionalEmbedding, self).__init__()
+         # Compute the positional encodings once in log space.
+         pe = torch.zeros(max_len, d_model).float()
+         pe.require_grad = False
+
+         position = torch.arange(0, max_len).float().unsqueeze(1)
+         div_term = (torch.arange(0, d_model, 2).float()
+                     * -(math.log(10000.0) / d_model)).exp()
+
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         return self.pe[:, :x.size(1)]
+
+
+ class TokenEmbedding(nn.Module):
+     def __init__(self, c_in, d_model):
+         super(TokenEmbedding, self).__init__()
+         padding = 1 if torch.__version__ >= '1.5.0' else 2
+         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
+                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)
+         for m in self.modules():
+             if isinstance(m, nn.Conv1d):
+                 nn.init.kaiming_normal_(
+                     m.weight, mode='fan_in', nonlinearity='leaky_relu')
+
+     def forward(self, x):
+         x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
+         return x
+
+
+ class FixedEmbedding(nn.Module):
+     def __init__(self, c_in, d_model):
+         super(FixedEmbedding, self).__init__()
+
+         w = torch.zeros(c_in, d_model).float()
+         w.require_grad = False
+
+         position = torch.arange(0, c_in).float().unsqueeze(1)
+         div_term = (torch.arange(0, d_model, 2).float()
+                     * -(math.log(10000.0) / d_model)).exp()
+
+         w[:, 0::2] = torch.sin(position * div_term)
+         w[:, 1::2] = torch.cos(position * div_term)
+
+         self.emb = nn.Embedding(c_in, d_model)
+         self.emb.weight = nn.Parameter(w, requires_grad=False)
+
+     def forward(self, x):
+         return self.emb(x).detach()
+
+
+ class TemporalEmbedding(nn.Module):
+     def __init__(self, d_model, embed_type='fixed', freq='h'):
+         super(TemporalEmbedding, self).__init__()
+
+         minute_size = 4
+         hour_size = 24
+         weekday_size = 7
+         day_size = 32
+         month_size = 13
+
+         Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
+         if freq == 't':
+             self.minute_embed = Embed(minute_size, d_model)
+         self.hour_embed = Embed(hour_size, d_model)
+         self.weekday_embed = Embed(weekday_size, d_model)
+         self.day_embed = Embed(day_size, d_model)
+         self.month_embed = Embed(month_size, d_model)
+
+     def forward(self, x):
+         x = x.long()
+         minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
+             self, 'minute_embed') else 0.
+         hour_x = self.hour_embed(x[:, :, 3])
+         weekday_x = self.weekday_embed(x[:, :, 2])
+         day_x = self.day_embed(x[:, :, 1])
+         month_x = self.month_embed(x[:, :, 0])
+
+         return hour_x + weekday_x + day_x + month_x + minute_x
+
+
+ class TimeFeatureEmbedding(nn.Module):
+     def __init__(self, d_model, embed_type='timeF', freq='h'):
+         super(TimeFeatureEmbedding, self).__init__()
+
+         freq_map = {'h': 4, 't': 5, 's': 6,
+                     'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
+         d_inp = freq_map[freq]
+         self.embed = nn.Linear(d_inp, d_model, bias=False)
+
+     def forward(self, x):
+         return self.embed(x)
+
+
+ class DataEmbedding(nn.Module):
+     def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
+         super(DataEmbedding, self).__init__()
+
+         self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
+         self.position_embedding = PositionalEmbedding(d_model=d_model)
+         self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
+                                                     freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
+             d_model=d_model, embed_type=embed_type, freq=freq)
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, x, x_mark):
+         if x_mark is None:
+             x = self.value_embedding(x) + self.position_embedding(x)
+         else:
+             x = self.value_embedding(
+                 x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
+         return self.dropout(x)
+
+
+ class DataEmbedding_inverted(nn.Module):
+     def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
+         super(DataEmbedding_inverted, self).__init__()
+         self.value_embedding = nn.Linear(c_in, d_model)
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, x, x_mark):
+         x = x.permute(0, 2, 1)
+         # x: [Batch Variate Time]
+         if x_mark is None:
+             x = self.value_embedding(x)
+         else:
+             # the potential to take covariates (e.g. timestamps) as tokens
+             x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
+         # x: [Batch Variate d_model]
+         return self.dropout(x)
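Unlike the Informer embedding above, DataEmbedding_inverted embeds each variate's entire series as a single token, so c_in is the sequence length and the output carries one token per variate plus one per time-feature column. A minimal sketch, not part of the package, with illustrative shapes:

```python
# Usage sketch (not from dsipts): one token per variate, as in iTransformer.
import torch
from dsipts.models.itransformer.Embed import DataEmbedding_inverted

B, seq_len, n_vars, n_marks, d_model = 2, 96, 7, 4, 64
emb = DataEmbedding_inverted(c_in=seq_len, d_model=d_model, dropout=0.1)

x = torch.randn(B, seq_len, n_vars)        # [Batch, Time, Variate]
x_mark = torch.randn(B, seq_len, n_marks)  # time features become extra tokens
tokens = emb(x, x_mark)                    # -> (B, n_vars + n_marks, d_model)
```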