dsipts 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dsipts may be problematic — see the package advisory for details.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from .attn import TwoStageAttentionLayer
|
|
5
|
+
from math import ceil
|
|
6
|
+
|
|
7
|
+
class SegMerging(nn.Module):
    '''
    Segment Merging Layer.
    The adjacent `win_size' segments in each dimension will be merged into one segment to
    get representation of a coarser scale
    we set win_size = 2 in our paper

    Cleanup: removed commented-out pdb debugging left in the forward pass.
    '''
    def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
        super().__init__()
        self.d_model = d_model
        self.win_size = win_size
        # project the win_size concatenated segments back down to d_model
        self.linear_trans = nn.Linear(win_size * d_model, d_model)
        self.norm = norm_layer(win_size * d_model)

    def forward(self, x):
        """
        x: B, ts_d, L, d_model
        """
        batch_size, ts_d, seg_num, d_model = x.shape
        pad_num = seg_num % self.win_size
        if pad_num != 0:
            # repeat the last segments so seg_num becomes divisible by win_size
            pad_num = self.win_size - pad_num
            x = torch.cat((x, x[:, :, -pad_num:, :]), dim = -2)

        # interleave: segment i, i+win_size, i+2*win_size, ... for each offset i
        seg_to_merge = []
        for i in range(self.win_size):
            seg_to_merge.append(x[:, :, i::self.win_size, :])

        x = torch.cat(seg_to_merge, -1)  # [B, ts_d, seg_num/win_size, win_size*d_model]

        x = self.norm(x)
        x = self.linear_trans(x)

        return x
|
|
43
|
+
|
|
44
|
+
class scale_block(nn.Module):
    """One encoder scale: an optional segment-merging step followed by
    `depth` TwoStageAttention layers (the paper uses depth = 1).
    """
    def __init__(self, win_size, d_model, n_heads, d_ff, depth, dropout, \
                seg_num = 10, factor=10):
        super(scale_block, self).__init__()

        # merging is only needed when moving to a coarser scale
        if win_size > 1:
            self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
        else:
            self.merge_layer = None

        self.encode_layers = nn.ModuleList(
            TwoStageAttentionLayer(seg_num, factor, d_model, n_heads, d_ff, dropout)
            for _ in range(depth)
        )

    def forward(self, x):
        _, ts_dim, _, _ = x.shape

        if self.merge_layer is not None:
            x = self.merge_layer(x)

        for attn_layer in self.encode_layers:
            x = attn_layer(x)

        return x
|
|
75
|
+
|
|
76
|
+
class Encoder(nn.Module):
    '''
    The Encoder of Crossformer.

    A stack of `e_blocks` scale blocks: the first keeps the input resolution
    (win_size = 1); each later block first merges `win_size` adjacent
    segments, shrinking the segment count before attending.
    '''
    def __init__(self, e_blocks, win_size, d_model, n_heads, d_ff, block_depth, dropout,
                in_seg_num = 10, factor=10):
        super(Encoder, self).__init__()
        blocks = [scale_block(1, d_model, n_heads, d_ff, block_depth, dropout,
                              in_seg_num, factor)]
        blocks.extend(
            scale_block(win_size, d_model, n_heads, d_ff, block_depth, dropout,
                        ceil(in_seg_num / win_size ** i), factor)
            for i in range(1, e_blocks)
        )
        self.encode_blocks = nn.ModuleList(blocks)

    def forward(self, x):
        # collect the representation at every scale, including the raw input
        encode_x = [x]
        for block in self.encode_blocks:
            x = block(x)
            encode_x.append(x)
        return encode_x
|
|
File without changes
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
# -*-Encoding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
Authors:
|
|
4
|
+
Li,Yan (liyan22021121@gmail.com)
|
|
5
|
+
"""
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from functools import partial
|
|
9
|
+
from inspect import isfunction
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
    """Build a diffusion beta schedule as a float64 numpy array of length
    `num_diffusion_timesteps`.

    Supported kinds: 'quad' (linear in sqrt space, then squared), 'linear',
    'const' (all beta_end) and 'jsd' (1/T, 1/(T-1), ..., 1).
    Raises NotImplementedError for any other kind.
    """
    steps = num_diffusion_timesteps
    if beta_schedule == 'quad':
        betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, steps, dtype=np.float64) ** 2
    elif beta_schedule == 'linear':
        betas = np.linspace(beta_start, beta_end, steps, dtype=np.float64)
    elif beta_schedule == 'const':
        betas = beta_end * np.ones(steps, dtype=np.float64)
    elif beta_schedule == 'jsd':
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1. / np.linspace(steps, 1, steps, dtype=np.float64)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (steps,)
    return betas
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def default(val, d):
    """Return `val` unless it is None; otherwise return the default `d`,
    calling it first when it is a function (lazy default)."""
    if val is None:
        return d() if isfunction(d) else d
    return val
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def extract(a, t, x_shape):
    """Gather schedule values `a` at timestep indices `t` and reshape the
    result to [B, 1, ..., 1] so it broadcasts against a tensor of shape
    `x_shape`.

    Cleanup: removed commented-out debug print statements.
    """
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def noise_like(shape, device, repeat=False):
    """Sample standard-normal noise of `shape` on `device`.

    With repeat=True a single row of shape (1, *shape[1:]) is drawn and
    repeated along the batch dimension, so every batch element shares the
    same noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class GaussianDiffusion(nn.Module):
    """
    Coupled diffusion process on top of a bidirectional VAE (BVAE).

    Two related noise schedules are kept: the standard one for the input
    series and a softened one (betas scaled by `scale`) for the target
    series.  The noisy input is fed through the BVAE, whose decoder output
    is the generative distribution.

    Params:
        bvae: the bidirectional VAE model.
        input_size: size of the input window (kept for interface parity).
        beta_start: first value of the beta schedule.
        beta_end: last value of the beta schedule.
        diff_steps: maximum number of diffusion steps.
        loss_type: identifier of the reconstruction loss (stored only).
        betas: unused override hook, kept for interface parity.
        scale: scale factor applied to the target schedule's betas.
        beta_schedule: kind of beta schedule; fixed to 'linear' here,
            adjust as needed.
    """
    def __init__(
        self,
        bvae,
        input_size,
        beta_start=0,
        beta_end=0.1,
        diff_steps=100,
        loss_type="l2",
        betas=None,
        scale = 0.1,
        beta_schedule="linear",
    ):
        super().__init__()
        self.generative = bvae
        self.scale = scale
        self.beta_start = beta_start
        self.beta_end = beta_end

        # input-series schedule
        betas = get_beta_schedule(beta_schedule, beta_start, beta_end, diff_steps)
        alphas_cumprod = np.cumprod(1.0 - betas, axis=0)

        # target-series schedule: same betas, softened by `scale`
        alphas_target = 1.0 - betas * scale
        alphas_target_cumprod = np.cumprod(alphas_target, axis=0)
        self.alphas_target = alphas_target
        self.alphas_target_cumprod = alphas_target_cumprod

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # register the schedule tensors as buffers so they follow the module's device
        to_torch = partial(torch.tensor, dtype=torch.float32)
        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer("sqrt_alphas_target_cumprod", to_torch(np.sqrt(alphas_target_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_target_cumprod", to_torch(np.sqrt(1.0 - alphas_target_cumprod))
        )

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the initial input to timestep `t`.
        :param x_start: [B, T, *]
        :return: [B, T, *]
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        perturbation = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        return signal + perturbation

    def q_sample_target(self, y_target, t, noise=None):
        """
        Diffuse the target to timestep `t` using the softened schedule.
        :param y_target: [B1, T1, *]
        :return: (tensor) [B1, T1, *]
        """
        noise = default(noise, lambda: torch.randn_like(y_target))
        signal = extract(self.sqrt_alphas_target_cumprod, t, y_target.shape) * y_target
        perturbation = extract(self.sqrt_one_minus_alphas_target_cumprod, t, y_target.shape) * noise
        return signal + perturbation

    def p_losses(self, x_start, y_target, t, noise=None, noise1=None):
        """
        Diffuse input and target to step `t`, then run the BVAE on the noisy input.

        :param x_start: [B, T, *]
        :param y_target: [B1, T1, *]
        :param t: [B,]
        :return output: the distribution of generative results.
        :return y_noisy: diffused target.
        :return total_c: the total correlations of latent variables in BVAE.
        :return all_z: all latent variables of BVAE.
        """
        B, T, _ = x_start.shape
        B1, T1, _ = y_target.shape
        x_start = x_start.reshape(B, 1, T, -1)
        y_target = y_target.reshape(B1, 1, T1, -1)

        noise = default(noise, lambda: torch.randn_like(x_start))
        noise1 = default(noise1, lambda: torch.randn_like(y_target))

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise.to(x_start.device))
        y_noisy = self.q_sample_target(y_target=y_target, t=t, noise=noise1.to(y_target.device))

        x_noisy = x_noisy.reshape(B, 1, T, -1)
        y_noisy = y_noisy.reshape(B1, 1, T1, -1)

        # encoder forward pass of the BVAE, then its decoder head
        logits, total_c, all_z = self.generative(x_noisy)
        output = self.generative.decoder_output(logits)
        return output, y_noisy, total_c, all_z

    def log_prob(self, x_input, y_target, time):
        """Convenience wrapper: diffuse and run the generative model."""
        return self.p_losses(x_input, y_target, time)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# -*-Encoding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
Authors:
|
|
4
|
+
Li,Yan (liyan22021121@gmail.com)
|
|
5
|
+
"""
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import math
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    The table is computed once in log space and registered as a buffer, so
    it moves with the module's device but is never trained.

    Fix: the original set `pe.require_grad = False`, a typo for
    `requires_grad` that only attached a useless attribute; the line is
    dropped because a freshly created tensor already has
    requires_grad=False and buffers are excluded from training anyway.
    """
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Return the table truncated to x's sequence length: [1, seq_len, d_model].
        return self.pe[:, :x.size(1)]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TokenEmbedding(nn.Module):
    """Embed each time step with a circular 1-D convolution over the time
    axis (kernel 3, no bias, Kaiming-initialized)."""
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # torch < 1.5 needed padding=2 to keep the sequence length with circular mode
        padding = 1 if torch.__version__ >= '1.5.0' else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding,
                                   padding_mode='circular', bias=False)
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, mode='fan_in',
                                        nonlinearity='leaky_relu')

    def forward(self, x):
        # [B, T, c_in] -> conv over T -> [B, T, d_model]
        return self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TemporalEmbedding(nn.Module):
    """Embed calendar features and fuse them with a linear layer.

    Input columns are (month, day, weekday, hour) and, when freq == 't',
    a fifth quarter-hour column; each gets its own nn.Embedding and the
    concatenation is projected back to d_model.
    """
    def __init__(self, d_model, freq='h'):
        super(TemporalEmbedding, self).__init__()

        # cardinalities of each calendar field
        minute_size = 4
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = nn.Embedding
        if freq == 't':
            self.minute_embed = Embed(minute_size, d_model)
            self.fc = nn.Linear(5*d_model, d_model)
        else:
            self.fc = nn.Linear(4*d_model, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        # x: [B, T, F] with columns (month, day, weekday, hour[, minute])
        x = x.long()
        hour_x = self.hour_embed(x[:,:,3])
        weekday_x = self.weekday_embed(x[:,:,2])
        day_x = self.day_embed(x[:,:,1])
        month_x = self.month_embed(x[:,:,0])
        if hasattr(self, 'minute_embed'):
            minute_x = self.minute_embed(x[:,:,4])
            out = torch.cat((minute_x, hour_x, weekday_x, day_x, month_x), dim=2)
        else:
            out = torch.cat((hour_x, weekday_x, day_x, month_x), dim=2)
        return self.fc(out)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class DataEmbedding(nn.Module):
    """Combine value (token), categorical and positional embeddings.

    Params:
        c_in: number of input channels of the value series.
        d_model: embedding dimension.
        embs: iterable with the cardinality of each categorical covariate
            carried in `x_mark`; None (or empty) disables categorical terms.
        dropout: dropout probability applied to the summed embedding.
    """
    def __init__(self, c_in, d_model, embs, dropout=0.1):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.emb_list = nn.ModuleList()
        if embs is not None:
            for k in embs:
                # k+1 leaves room for an out-of-range / padding index
                self.emb_list.append(nn.Embedding(k+1,d_model))

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        # BUG FIX: `tot` used to start as None, so `+ tot` below raised a
        # TypeError whenever `embs` was None or empty; start from 0 instead
        # (0 + tensor broadcasts to a no-op).
        tot = 0
        for i in range(len(self.emb_list)):
            tot = tot + self.emb_list[i](x_mark[:,:,i])

        x = self.value_embedding(x) + tot + self.position_embedding(x)
        return self.dropout(x)
|