dsipts-1.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/duet/layers.py
@@ -0,0 +1,438 @@
+import torch
+import torch.nn as nn
+from torch.distributions.normal import Normal
+from einops import rearrange
+from ..autoformer.layers import series_decomp
+
+class encoder(nn.Module):
+    def __init__(self, seq_len, num_experts, hidden_size):
+        super(encoder, self).__init__()
+        input_size = seq_len
+        num_experts = num_experts
+        encoder_hidden_size = hidden_size
+
+        self.distribution_fit = nn.Sequential(nn.Linear(input_size, encoder_hidden_size, bias=False), nn.ReLU(),
+                                              nn.Linear(encoder_hidden_size, num_experts, bias=False))
+
+    def forward(self, x):
+        mean = torch.mean(x, dim=-1)
+        out = self.distribution_fit(mean)
+        return out
+
+
+class RevIN(nn.Module):
+    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
+        """
+        :param num_features: the number of features or channels
+        :param eps: a value added for numerical stability
+        :param affine: if True, RevIN has learnable affine parameters
+        """
+        super(RevIN, self).__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
+        self.subtract_last = subtract_last
+        if self.affine:
+            self._init_params()
+
+    def forward(self, x, mode: str):
+        if mode == 'norm':
+            self._get_statistics(x)
+            x = self._normalize(x)
+        elif mode == 'denorm':
+            x = self._denormalize(x)
+        else:
+            raise NotImplementedError
+        return x
+
+    def _init_params(self):
+        # initialize RevIN params: (C,)
+        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
+        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
+
+    def _get_statistics(self, x):
+        dim2reduce = tuple(range(1, x.ndim - 1))
+        if self.subtract_last:
+            self.last = x[:, -1, :].unsqueeze(1)
+        else:
+            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
+        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
+
+    def _normalize(self, x):
+        if self.subtract_last:
+            x = x - self.last
+        else:
+            x = x - self.mean
+        x = x / self.stdev
+        if self.affine:
+            x = x * self.affine_weight
+            x = x + self.affine_bias
+        return x
+
+    def _denormalize(self, x):
+        if self.affine:
+            x = x - self.affine_bias
+            x = x / (self.affine_weight + self.eps * self.eps)
+        x = x * self.stdev
+        if self.subtract_last:
+            x = x + self.last
+        else:
+            x = x + self.mean
+        return x
+
+
+class Linear_extractor(nn.Module):
+    """
+    Paper link: https://arxiv.org/pdf/2205.13504.pdf
+    """
+
+    def __init__(self, seq_len, d_model, enc_in, CI, moving_avg, individual=False):
+        """
+        individual: Bool, whether each variate gets its own linear layers.
+        """
+        super(Linear_extractor, self).__init__()
+        self.seq_len = seq_len
+
+        self.pred_len = d_model
+        self.decompsition = series_decomp(moving_avg)
+        self.individual = individual
+        self.channels = enc_in
+        self.enc_in = 1 if CI else enc_in
+        if self.individual:
+            self.Linear_Seasonal = nn.ModuleList()
+            self.Linear_Trend = nn.ModuleList()
+
+            for i in range(self.channels):
+                self.Linear_Seasonal.append(
+                    nn.Linear(self.seq_len, self.pred_len))
+                self.Linear_Trend.append(
+                    nn.Linear(self.seq_len, self.pred_len))
+
+                self.Linear_Seasonal[i].weight = nn.Parameter(
+                    (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+                self.Linear_Trend[i].weight = nn.Parameter(
+                    (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+        else:
+            self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
+            self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
+
+            self.Linear_Seasonal.weight = nn.Parameter(
+                (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+            self.Linear_Trend.weight = nn.Parameter(
+                (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+
+
+
+    def encoder(self, x):
+        seasonal_init, trend_init = self.decompsition(x)
+        seasonal_init, trend_init = seasonal_init.permute(
+            0, 2, 1), trend_init.permute(0, 2, 1)
+        if self.individual:
+            seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size(1), self.pred_len],
+                                          dtype=seasonal_init.dtype).to(seasonal_init.device)
+            trend_output = torch.zeros([trend_init.size(0), trend_init.size(1), self.pred_len],
+                                       dtype=trend_init.dtype).to(trend_init.device)
+            for i in range(self.channels):
+                seasonal_output[:, i, :] = self.Linear_Seasonal[i](
+                    seasonal_init[:, i, :])
+                trend_output[:, i, :] = self.Linear_Trend[i](
+                    trend_init[:, i, :])
+        else:
+            seasonal_output = self.Linear_Seasonal(seasonal_init)
+            trend_output = self.Linear_Trend(trend_init)
+        x = seasonal_output + trend_output
+        return x.permute(0, 2, 1)
+
+    def forecast(self, x_enc):
+        # Encoder
+        return self.encoder(x_enc)
+
+
+    def forward(self, x_enc):
+        if x_enc.shape[0] == 0:
+            return torch.empty((0, self.pred_len, self.enc_in)).to(x_enc.device)
+        dec_out = self.forecast(x_enc)
+        return dec_out[:, -self.pred_len:, :]  # [B, L, D]
+
+
+class SparseDispatcher(object):
+    """Helper for implementing a mixture of experts.
+    The purpose of this class is to create input mini-batches for the
+    experts and to combine the results of the experts to form a unified
+    output tensor.
+    There are two functions:
+    dispatch - take an input Tensor and create input Tensors for each expert.
+    combine - take output Tensors from each expert and form a combined output
+      Tensor. Outputs from different experts for the same batch element are
+      summed together, weighted by the provided "gates".
+    The class is initialized with a "gates" Tensor, which specifies which
+    batch elements go to which experts, and the weights to use when combining
+    the outputs. Batch element b is sent to expert e iff gates[b, e] != 0.
+    The inputs and outputs are all two-dimensional [batch, depth].
+    Caller is responsible for collapsing additional dimensions prior to
+    calling this class and reshaping the output to the original shape.
+    See common_layers.reshape_like().
+    Example use:
+    gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
+    inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
+    experts: a list of length `num_experts` containing sub-networks.
+    dispatcher = SparseDispatcher(num_experts, gates)
+    expert_inputs = dispatcher.dispatch(inputs)
+    expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
+    outputs = dispatcher.combine(expert_outputs)
+    The preceding code sets the output for a particular example b to:
+    output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))
+    This class takes advantage of sparsity in the gate matrix by including in the
+    `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.
+    """
+
+    def __init__(self, num_experts, gates):
+        """Create a SparseDispatcher."""
+
+        self._gates = gates
+        self._num_experts = num_experts
+        # sort experts
+        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
+        # drop indices
+        _, self._expert_index = sorted_experts.split(1, dim=1)
+        # get according batch index for each expert
+        self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
+        # calculate num samples that each expert gets
+        self._part_sizes = (gates > 0).sum(0).tolist()
+        # expand gates to match with self._batch_index
+        gates_exp = gates[self._batch_index.flatten()]
+        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)
+
+    def dispatch(self, inp):
+        """Create one input Tensor for each expert.
+        The `Tensor` for an expert `i` contains the slices of `inp` corresponding
+        to the batch elements `b` where `gates[b, i] > 0`.
+        Args:
+          inp: a `Tensor` of shape `[batch_size, <extra_input_dims>]`
+        Returns:
+          a list of `num_experts` `Tensor`s with shapes
+            `[expert_batch_size_i, <extra_input_dims>]`.
+        """
+
+        # assigns samples to experts whose gate is nonzero
+
+        # expand according to batch index so we can just split by _part_sizes
+        inp_exp = inp[self._batch_index].squeeze(1)
+        return torch.split(inp_exp, self._part_sizes, dim=0)
+
+    def combine(self, expert_out, multiply_by_gates=True):
+        """Sum together the expert output, weighted by the gates.
+        The slice corresponding to a particular batch element `b` is computed
+        as the sum over all experts `i` of the expert output, weighted by the
+        corresponding gate values. If `multiply_by_gates` is set to False, the
+        gate values are ignored.
+        Args:
+          expert_out: a list of `num_experts` `Tensor`s, each with shape
+            `[expert_batch_size_i, <extra_output_dims>]`.
+          multiply_by_gates: a boolean
+        Returns:
+          a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
+        """
+        # concatenate the per-expert outputs back into a single batch
+        stitched = torch.cat(expert_out, 0)
+        if multiply_by_gates:
+            # stitched = stitched.mul(self._nonzero_gates)
+            stitched = torch.einsum("i...,ij->i...", stitched, self._nonzero_gates)
+
+        shape = list(expert_out[-1].shape)
+        shape[0] = self._gates.size(0)
+        zeros = torch.zeros(*shape, requires_grad=True, device=stitched.device)
+        # combine samples that have been processed by the same k experts
+        combined = zeros.index_add(0, self._batch_index, stitched.float())
+        return combined
+
+    def expert_to_gates(self):
+        """Gate values corresponding to the examples in the per-expert `Tensor`s.
+        Returns:
+          a list of `num_experts` one-dimensional `Tensor`s with type `float32`
+            and shapes `[expert_batch_size_i]`
+        """
+        # split nonzero gates for each expert
+        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
+
+
+class Linear_extractor_cluster(nn.Module):
+    """A sparsely gated mixture-of-experts layer with Linear_extractor modules as experts.
+    Args:
+    input_size: integer - size of the input
+    output_size: integer - size of the output
+    num_experts: an integer - number of experts
+    hidden_size: an integer - hidden size of the experts
+    noisy_gating: a boolean
+    k: an integer - how many experts to use for each batch element
+    """
+
+    def __init__(self, noisy_gating, num_experts, seq_len, k, d_model, enc_in, CI, moving_avg, hidden_size):
+        super(Linear_extractor_cluster, self).__init__()
+        self.noisy_gating = noisy_gating
+        self.num_experts = num_experts
+        self.input_size = seq_len
+        self.k = k
+        # instantiate experts
+        self.experts = nn.ModuleList([Linear_extractor(seq_len, d_model, enc_in, CI, moving_avg) for _ in range(self.num_experts)])
+        self.W_h = nn.Parameter(torch.eye(self.num_experts))
+        self.gate = encoder(seq_len, num_experts, hidden_size)
+        self.noise = encoder(seq_len, num_experts, hidden_size)
+
+        self.n_vars = enc_in
+        self.revin = RevIN(self.n_vars)
+
+        self.CI = CI
+        self.softplus = nn.Softplus()
+        self.softmax = nn.Softmax(1)
+        self.register_buffer("mean", torch.tensor([0.0]))
+        self.register_buffer("std", torch.tensor([1.0]))
+        assert self.k <= self.num_experts
+
+    def cv_squared(self, x):
+        """The squared coefficient of variation of a sample.
+        Useful as a loss to encourage a positive distribution to be more uniform.
+        Epsilons added for numerical stability.
+        Returns 0 for an empty Tensor.
+        Args:
+        x: a `Tensor`.
+        Returns:
+        a `Scalar`.
+        """
+        eps = 1e-10
+        # if only num_experts = 1
+
+        if x.shape[0] == 1:
+            return torch.tensor([0], device=x.device, dtype=x.dtype)
+        return x.float().var() / (x.float().mean() ** 2 + eps)
+
+    def _gates_to_load(self, gates):
+        """Compute the true load per expert, given the gates.
+        The load is the number of examples for which the corresponding gate is >0.
+        Args:
+        gates: a `Tensor` of shape [batch_size, n]
+        Returns:
+        a float32 `Tensor` of shape [n]
+        """
+        return (gates > 0).sum(0)
+
+    def _prob_in_top_k(
+        self, clean_values, noisy_values, noise_stddev, noisy_top_values
+    ):
+        """Helper function to NoisyTopKGating.
+        Computes the probability that value is in top k, given different random noise.
+        This gives us a way of backpropagating from a loss that balances the number
+        of times each expert is in the top k experts per example.
+        In the case of no noise, pass in None for noise_stddev, and the result will
+        not be differentiable.
+        Args:
+        clean_values: a `Tensor` of shape [batch, n].
+        noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
+        normally distributed noise with standard deviation noise_stddev.
+        noise_stddev: a `Tensor` of shape [batch, n], or None
+        noisy_top_values: a `Tensor` of shape [batch, m].
+        The "values" output of torch.topk(noisy_values, m). m >= k+1
+        Returns:
+        a `Tensor` of shape [batch, n].
+        """
+        batch = clean_values.size(0)
+        m = noisy_top_values.size(1)
+        top_values_flat = noisy_top_values.flatten()
+
+        threshold_positions_if_in = (
+            torch.arange(batch, device=clean_values.device) * m + self.k
+        )
+        threshold_if_in = torch.unsqueeze(
+            torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
+        )
+        is_in = torch.gt(noisy_values, threshold_if_in)
+        threshold_positions_if_out = threshold_positions_if_in - 1
+        threshold_if_out = torch.unsqueeze(
+            torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
+        )
+        # is each value currently in the top k.
+        normal = Normal(self.mean, self.std)
+        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
+        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
+        prob = torch.where(is_in, prob_if_in, prob_if_out)
+        return prob
+
+    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
+        """Noisy top-k gating.
+        See paper: https://arxiv.org/abs/1701.06538.
+        Args:
+        x: input Tensor with shape [batch_size, input_size]
+        train: a boolean - we only add noise at training time.
+        noise_epsilon: a float
+        Returns:
+        gates: a Tensor with shape [batch_size, num_experts]
+        load: a Tensor with shape [num_experts]
+        """
+        clean_logits = self.gate(x)
+
+        if self.noisy_gating and train:
+            raw_noise_stddev = self.noise(x)
+            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
+            noise = torch.randn_like(clean_logits)
+            noisy_logits = clean_logits + (noise * noise_stddev)
+            logits = noisy_logits @ self.W_h
+        else:
+            logits = clean_logits
+
+        # calculate topk + 1 that will be needed for the noisy gates
+        logits = self.softmax(logits)
+        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
+        top_k_logits = top_logits[:, : self.k]
+        top_k_indices = top_indices[:, : self.k]
+        top_k_gates = top_k_logits / (
+            top_k_logits.sum(1, keepdim=True) + 1e-6
+        )  # normalization
+
+        zeros = torch.zeros_like(logits, requires_grad=True)
+        gates = zeros.scatter(1, top_k_indices, top_k_gates)
+
+        if self.noisy_gating and self.k < self.num_experts and train:
+            load = (
+                self._prob_in_top_k(
+                    clean_logits, noisy_logits, noise_stddev, top_logits
+                )
+            ).sum(0)
+        else:
+            load = self._gates_to_load(gates)
+        return gates, load
+
+    def forward(self, x, loss_coef=1):
+        """Args:
+        x: tensor shape [batch_size, input_size]
+        train: a boolean scalar (read from self.training).
+        loss_coef: a scalar - multiplier on load-balancing losses
+
+        Returns:
+        y: a tensor with shape [batch_size, output_size].
+        extra_training_loss: a scalar. This should be added into the overall
+        training loss of the model. The backpropagation of this loss
+        encourages all experts to be approximately equally used across a batch.
+        """
+        gates, load = self.noisy_top_k_gating(x, self.training)
+        # calculate importance loss
+        importance = gates.sum(0)
+        loss = self.cv_squared(importance) + self.cv_squared(load)
+        loss *= loss_coef
+
+        dispatcher = SparseDispatcher(self.num_experts, gates)
+        if self.CI:
+            x_norm = rearrange(x, "(x y) l c -> x l (y c)", y=self.n_vars)
+            x_norm = self.revin(x_norm, "norm")
+            x_norm = rearrange(x_norm, "x l (y c) -> (x y) l c", y=self.n_vars)
+        else:
+            x_norm = self.revin(x, "norm")
+
+        expert_inputs = dispatcher.dispatch(x_norm)
+
+        gates = dispatcher.expert_to_gates()
+        expert_outputs = [
+            self.experts[i](expert_inputs[i]) for i in range(self.num_experts)
+        ]
+        y = dispatcher.combine(expert_outputs)
+
+        return y, loss
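The hunk above is the DUET feature extractor: a sparsely gated mixture of DLinear-style experts. The small `encoder` maps the input window to gate and noise logits, `noisy_top_k_gating` keeps the k largest normalized gate values per sample (see https://arxiv.org/abs/1701.06538), `SparseDispatcher` routes RevIN-normalized slices only to experts with nonzero gates, and the two `cv_squared` terms form a load-balancing loss. A minimal smoke-test sketch follows; it assumes the wheel is installed so the class imports from the path in the file list above, and every size here is illustrative rather than a dsipts default.

import torch
from dsipts.models.duet.layers import Linear_extractor_cluster

B, L, C = 8, 96, 7                  # batch, input window, channels
cluster = Linear_extractor_cluster(
    noisy_gating=True,              # perturb gate logits with learned noise while training
    num_experts=4,
    seq_len=L,
    k=2,                            # experts kept per batch element
    d_model=128,                    # each expert maps seq_len -> d_model
    enc_in=C,
    CI=False,                       # channel-independent reshaping disabled for this sketch
    moving_avg=25,                  # kernel size of the series_decomp trend filter
    hidden_size=64,                 # hidden width of the gating encoder
)

x = torch.randn(B, L, C)
y, aux_loss = cluster(x)            # y: [B, d_model, C]; aux_loss: scalar balance loss
print(y.shape, float(aux_loss))

The returned aux_loss is meant to be scaled (loss_coef) and added to the task loss, so that training does not collapse all traffic onto a single expert.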
dsipts/models/duet/masked.py
@@ -0,0 +1,202 @@
+import torch.nn as nn
+import torch
+from math import sqrt
+
+import torch.nn.functional as F
+from torch.nn.functional import gumbel_softmax
+import math
+import torch.fft
+from einops import rearrange
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
+        super(EncoderLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.attention = attention
+        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None):
+        new_x, attn = self.attention(
+            x, x, x,
+            attn_mask=attn_mask,
+            tau=tau, delta=delta
+        )
+        x = x + self.dropout(new_x)
+
+        y = x = self.norm1(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm2(x + y), attn
+
+
+class Encoder(nn.Module):
+    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
+        super(Encoder, self).__init__()
+        self.attn_layers = nn.ModuleList(attn_layers)
+        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
+        self.norm = norm_layer
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None):
+        # x [B, L, D]
+        attns = []
+        if self.conv_layers is not None:
+            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
+                delta = delta if i == 0 else None
+                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
+                x = conv_layer(x)
+                attns.append(attn)
+            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
+            attns.append(attn)
+        else:
+            for attn_layer in self.attn_layers:
+                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
+                attns.append(attn)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x, attns
+
+
+class FullAttention(nn.Module):
+    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
+        super(FullAttention, self).__init__()
+        self.scale = scale
+        self.mask_flag = mask_flag
+        self.output_attention = output_attention
+        self.dropout = nn.Dropout(attention_dropout)
+
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        B, L, H, E = queries.shape
+        _, S, _, D = values.shape
+        scale = self.scale or 1. / sqrt(E)
+
+        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+
+        # if self.mask_flag:
+        #     large_negative = -math.log(1e10)
+        #     attention_mask = torch.where(attn_mask == 0, torch.tensor(large_negative), attn_mask)
+        #
+        #     scores = scores * attention_mask
+        if self.mask_flag:
+            large_negative = -math.log(1e10)
+            attention_mask = torch.where(attn_mask == 0, large_negative, 0)
+
+            scores = scores * attn_mask + attention_mask
+
+        A = self.dropout(torch.softmax(scale * scores, dim=-1))
+        V = torch.einsum("bhls,bshd->blhd", A, values)
+
+        if self.output_attention:
+            return V.contiguous(), A
+        else:
+            return V.contiguous(), None
+
+
+class AttentionLayer(nn.Module):
+    def __init__(self, attention, d_model, n_heads, d_keys=None,
+                 d_values=None):
+        super(AttentionLayer, self).__init__()
+
+        d_keys = d_keys or (d_model // n_heads)
+        d_values = d_values or (d_model // n_heads)
+
+        self.inner_attention = attention
+        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
+        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
+        self.value_projection = nn.Linear(d_model, d_values * n_heads)
+        self.out_projection = nn.Linear(d_values * n_heads, d_model)
+        self.n_heads = n_heads
+
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        B, L, _ = queries.shape
+        _, S, _ = keys.shape
+        H = self.n_heads
+
+        queries = self.query_projection(queries).view(B, L, H, -1)
+        keys = self.key_projection(keys).view(B, S, H, -1)
+        values = self.value_projection(values).view(B, S, H, -1)
+
+        out, attn = self.inner_attention(
+            queries,
+            keys,
+            values,
+            attn_mask,
+            tau=tau,
+            delta=delta
+        )
+        out = out.view(B, L, -1)
+
+        return self.out_projection(out), attn
+
+
+class Mahalanobis_mask(nn.Module):
+    def __init__(self, input_size):
+        super(Mahalanobis_mask, self).__init__()
+        frequency_size = input_size // 2 + 1
+        self.A = nn.Parameter(torch.randn(frequency_size, frequency_size), requires_grad=True)
+
+    def calculate_prob_distance(self, X):
+        XF = torch.abs(torch.fft.rfft(X, dim=-1))
+        X1 = XF.unsqueeze(2)
+        X2 = XF.unsqueeze(1)
+
+        # B x C x C x D
+        diff = X1 - X2
+
+        temp = torch.einsum("dk,bxck->bxcd", self.A, diff)
+
+        dist = torch.einsum("bxcd,bxcd->bxc", temp, temp)
+
+        # exp_dist = torch.exp(-dist)
+        exp_dist = 1 / (dist + 1e-10)
+        # zero out the diagonal
+
+        identity_matrices = 1 - torch.eye(exp_dist.shape[-1])
+        mask = identity_matrices.repeat(exp_dist.shape[0], 1, 1).to(exp_dist.device)
+        exp_dist = torch.einsum("bxc,bxc->bxc", exp_dist, mask)
+        exp_max, _ = torch.max(exp_dist, dim=-1, keepdim=True)
+        exp_max = exp_max.detach()
+
+        # B x C x C
+        p = exp_dist / exp_max
+
+        identity_matrices = torch.eye(p.shape[-1])
+        p1 = torch.einsum("bxc,bxc->bxc", p, mask)
+
+        diag = identity_matrices.repeat(p.shape[0], 1, 1).to(p.device)
+        p = (p1 + diag) * 0.99
+
+        return p
+
+    def bernoulli_gumbel_rsample(self, distribution_matrix):
+        b, c, d = distribution_matrix.shape
+
+        flatten_matrix = rearrange(distribution_matrix, 'b c d -> (b c d) 1')
+        r_flatten_matrix = 1 - flatten_matrix
+
+        log_flatten_matrix = torch.log(flatten_matrix / r_flatten_matrix)
+        log_r_flatten_matrix = torch.log(r_flatten_matrix / flatten_matrix)
+
+        new_matrix = torch.concat([log_flatten_matrix, log_r_flatten_matrix], dim=-1)
+        resample_matrix = gumbel_softmax(new_matrix, hard=True)
+
+        resample_matrix = rearrange(resample_matrix[..., 0], '(b c d) -> b c d', b=b, c=c, d=d)
+        return resample_matrix
+
+    def forward(self, X):
+        p = self.calculate_prob_distance(X)
+
+        # probability, for the Bernoulli sample, that two channels are related
+        sample = self.bernoulli_gumbel_rsample(p)
+
+        mask = sample.unsqueeze(1)
+        cnt = torch.sum(mask, dim=-1)
+        return mask
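Taken together, this hunk is the masked channel attention used by the Duet model: `Mahalanobis_mask` learns a metric `A` over the rFFT magnitudes of each channel, converts pairwise distances into relation probabilities, and draws a hard 0/1 channel-pair mask via the Gumbel-Softmax trick; `FullAttention` then scales the scores by that mask and adds -log(1e10) to masked pairs before the softmax. A minimal wiring sketch, assuming the installed wheel; the sizes are illustrative and the actual composition in dsipts/models/Duet.py may differ in detail:

import torch
from dsipts.models.duet.masked import (
    AttentionLayer,
    Encoder,
    EncoderLayer,
    FullAttention,
    Mahalanobis_mask,
)

B, C, L, d_model, n_heads = 4, 7, 96, 64, 4

mask_gen = Mahalanobis_mask(input_size=L)   # learns A over L // 2 + 1 rFFT bins
channel_att = Encoder(
    [EncoderLayer(
        AttentionLayer(FullAttention(mask_flag=True, output_attention=True), d_model, n_heads),
        d_model,
    )]
)

series = torch.randn(B, C, L)        # raw per-channel windows, FFT-ed by the mask
tokens = torch.randn(B, C, d_model)  # per-channel embeddings to attend over

p_mask = mask_gen(series)            # [B, 1, C, C], hard 0/1 via Gumbel-Softmax
out, attns = channel_att(tokens, attn_mask=p_mask)
print(out.shape)                     # torch.Size([4, 7, 64])

Broadcasting does the rest: the [B, 1, C, C] mask is applied to the [B, n_heads, C, C] attention scores, so every head sees the same learned channel graph.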