dsipts 1.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/duet/layers.py
@@ -0,0 +1,438 @@
+ import torch
+ import torch.nn as nn
+ from torch.distributions.normal import Normal
+ from einops import rearrange
+ from ..autoformer.layers import series_decomp
+
+
+ class encoder(nn.Module):
+     """Gating head: maps the across-channel mean of the input window to per-expert logits."""
+
+     def __init__(self, seq_len, num_experts, hidden_size):
+         super(encoder, self).__init__()
+         input_size = seq_len
+         encoder_hidden_size = hidden_size
+
+         self.distribution_fit = nn.Sequential(
+             nn.Linear(input_size, encoder_hidden_size, bias=False),
+             nn.ReLU(),
+             nn.Linear(encoder_hidden_size, num_experts, bias=False),
+         )
+
+     def forward(self, x):
+         mean = torch.mean(x, dim=-1)  # [B, L, C] -> [B, L]
+         out = self.distribution_fit(mean)
+         return out
+
+
+ class RevIN(nn.Module):
+     def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
+         """
+         :param num_features: the number of features or channels
+         :param eps: a value added for numerical stability
+         :param affine: if True, RevIN has learnable affine parameters
+         :param subtract_last: if True, subtract the last time step instead of the mean
+         """
+         super(RevIN, self).__init__()
+         self.num_features = num_features
+         self.eps = eps
+         self.affine = affine
+         self.subtract_last = subtract_last
+         if self.affine:
+             self._init_params()
+
+     def forward(self, x, mode: str):
+         if mode == 'norm':
+             self._get_statistics(x)
+             x = self._normalize(x)
+         elif mode == 'denorm':
+             x = self._denormalize(x)
+         else:
+             raise NotImplementedError
+         return x
+
+     def _init_params(self):
+         # initialize RevIN params: (C,)
+         self.affine_weight = nn.Parameter(torch.ones(self.num_features))
+         self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
+
+     def _get_statistics(self, x):
+         dim2reduce = tuple(range(1, x.ndim - 1))
+         if self.subtract_last:
+             self.last = x[:, -1, :].unsqueeze(1)
+         else:
+             self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
+         # stdev is required by _normalize in both branches
+         self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
+
+     def _normalize(self, x):
+         if self.subtract_last:
+             x = x - self.last
+         else:
+             x = x - self.mean
+         x = x / self.stdev
+         if self.affine:
+             x = x * self.affine_weight
+             x = x + self.affine_bias
+         return x
+
+     def _denormalize(self, x):
+         if self.affine:
+             x = x - self.affine_bias
+             x = x / (self.affine_weight + self.eps * self.eps)
+         x = x * self.stdev
+         if self.subtract_last:
+             x = x + self.last
+         else:
+             x = x + self.mean
+         return x
+
+
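For orientation, a minimal round-trip sketch of how a RevIN instance is typically used (illustrative only, not part of the package; the [batch, length, channels] layout is an assumption inferred from `_get_statistics` reducing over the time dimension):

import torch

revin = RevIN(num_features=7)       # 7 channels
x = torch.randn(32, 96, 7)          # [batch, length, channels] (assumed layout)
x_norm = revin(x, 'norm')           # computes and stores mean/stdev, then normalizes
# ... a forecaster would consume x_norm here ...
x_back = revin(x_norm, 'denorm')    # inverts using the stored statistics
print(torch.allclose(x, x_back, atol=1e-4))  # True up to numerical error
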
+ class Linear_extractor(nn.Module):
+     """
+     DLinear-style extractor. Paper link: https://arxiv.org/pdf/2205.13504.pdf
+     """
+
+     def __init__(self, seq_len, d_model, enc_in, CI, moving_avg, individual=False):
+         """
+         individual: Bool, whether to use a separate linear layer per variate
+         (a single shared layer otherwise).
+         """
+         super(Linear_extractor, self).__init__()
+         self.seq_len = seq_len
+         self.pred_len = d_model
+         self.decompsition = series_decomp(moving_avg)
+         self.individual = individual
+         self.channels = enc_in
+         self.enc_in = 1 if CI else enc_in
+         if self.individual:
+             self.Linear_Seasonal = nn.ModuleList()
+             self.Linear_Trend = nn.ModuleList()
+
+             for i in range(self.channels):
+                 self.Linear_Seasonal.append(nn.Linear(self.seq_len, self.pred_len))
+                 self.Linear_Trend.append(nn.Linear(self.seq_len, self.pred_len))
+
+                 self.Linear_Seasonal[i].weight = nn.Parameter(
+                     (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+                 self.Linear_Trend[i].weight = nn.Parameter(
+                     (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+         else:
+             self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
+             self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
+
+             self.Linear_Seasonal.weight = nn.Parameter(
+                 (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+             self.Linear_Trend.weight = nn.Parameter(
+                 (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+
+     def encoder(self, x):
+         seasonal_init, trend_init = self.decompsition(x)
+         seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
+         if self.individual:
+             seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size(1), self.pred_len],
+                                           dtype=seasonal_init.dtype).to(seasonal_init.device)
+             trend_output = torch.zeros([trend_init.size(0), trend_init.size(1), self.pred_len],
+                                        dtype=trend_init.dtype).to(trend_init.device)
+             for i in range(self.channels):
+                 seasonal_output[:, i, :] = self.Linear_Seasonal[i](seasonal_init[:, i, :])
+                 trend_output[:, i, :] = self.Linear_Trend[i](trend_init[:, i, :])
+         else:
+             seasonal_output = self.Linear_Seasonal(seasonal_init)
+             trend_output = self.Linear_Trend(trend_init)
+         x = seasonal_output + trend_output
+         return x.permute(0, 2, 1)
+
+     def forecast(self, x_enc):
+         return self.encoder(x_enc)
+
+     def forward(self, x_enc):
+         if x_enc.shape[0] == 0:
+             return torch.empty((0, self.pred_len, self.enc_in)).to(x_enc.device)
+         dec_out = self.forecast(x_enc)
+         return dec_out[:, -self.pred_len:, :]  # [B, L, D]
+
+
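A shape-check sketch for the extractor (illustrative only; the hyperparameter values are made up, and `series_decomp` is assumed to take a moving-average kernel size, as in the Autoformer layers it is imported from):

import torch

extractor = Linear_extractor(seq_len=96, d_model=128, enc_in=7, CI=False, moving_avg=25)
x = torch.randn(32, 96, 7)   # [batch, seq_len, channels]
out = extractor(x)
print(out.shape)             # torch.Size([32, 128, 7]): seq_len is mapped to d_model
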
+ class SparseDispatcher(object):
+     """Helper for implementing a mixture of experts.
+     The purpose of this class is to create input mini-batches for the
+     experts and to combine the results of the experts to form a unified
+     output tensor.
+     There are two functions:
+     dispatch - take an input Tensor and create input Tensors for each expert.
+     combine - take output Tensors from each expert and form a combined output
+       Tensor. Outputs from different experts for the same batch element are
+       summed together, weighted by the provided "gates".
+     The class is initialized with a "gates" Tensor, which specifies which
+     batch elements go to which experts, and the weights to use when combining
+     the outputs. Batch element b is sent to expert e iff gates[b, e] != 0.
+     The inputs and outputs are all two-dimensional [batch, depth].
+     Caller is responsible for collapsing additional dimensions prior to
+     calling this class and reshaping the output to the original shape.
+     See common_layers.reshape_like().
+     Example use:
+         gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
+         inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
+         experts: a list of length `num_experts` containing sub-networks.
+         dispatcher = SparseDispatcher(num_experts, gates)
+         expert_inputs = dispatcher.dispatch(inputs)
+         expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
+         outputs = dispatcher.combine(expert_outputs)
+     The preceding code sets the output for a particular example b to:
+         output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))
+     This class takes advantage of sparsity in the gate matrix by including in the
+     `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.
+     """
+
+     def __init__(self, num_experts, gates):
+         """Create a SparseDispatcher."""
+         self._gates = gates
+         self._num_experts = num_experts
+         # sort experts
+         sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
+         # drop indices
+         _, self._expert_index = sorted_experts.split(1, dim=1)
+         # get the batch index for each expert
+         self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
+         # calculate the number of samples that each expert gets
+         self._part_sizes = (gates > 0).sum(0).tolist()
+         # expand gates to match with self._batch_index
+         gates_exp = gates[self._batch_index.flatten()]
+         self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)
+
+     def dispatch(self, inp):
+         """Create one input Tensor for each expert.
+         The `Tensor` for expert `i` contains the slices of `inp` corresponding
+         to the batch elements `b` where `gates[b, i] > 0`.
+         Args:
+             inp: a `Tensor` of shape `[batch_size, <extra_input_dims>]`
+         Returns:
+             a list of `num_experts` `Tensor`s with shapes
+             `[expert_batch_size_i, <extra_input_dims>]`.
+         """
+         # assign samples to the experts whose gates are nonzero;
+         # expand according to batch index so we can just split by _part_sizes
+         inp_exp = inp[self._batch_index].squeeze(1)
+         return torch.split(inp_exp, self._part_sizes, dim=0)
+
+     def combine(self, expert_out, multiply_by_gates=True):
+         """Sum together the expert output, weighted by the gates.
+         The slice corresponding to a particular batch element `b` is computed
+         as the sum over all experts `i` of the expert output, weighted by the
+         corresponding gate values. If `multiply_by_gates` is set to False, the
+         gate values are ignored.
+         Args:
+             expert_out: a list of `num_experts` `Tensor`s, each with shape
+                 `[expert_batch_size_i, <extra_output_dims>]`.
+             multiply_by_gates: a boolean
+         Returns:
+             a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
+         """
+         # concatenate the per-expert outputs back into one tensor
+         stitched = torch.cat(expert_out, 0)
+         if multiply_by_gates:
+             # stitched = stitched.mul(self._nonzero_gates)
+             stitched = torch.einsum("i...,ij->i...", stitched, self._nonzero_gates)
+
+         shape = list(expert_out[-1].shape)
+         shape[0] = self._gates.size(0)
+         zeros = torch.zeros(*shape, requires_grad=True, device=stitched.device)
+         # combine samples that have been processed by the same k experts
+         combined = zeros.index_add(0, self._batch_index, stitched.float())
+         return combined
+
+     def expert_to_gates(self):
+         """Gate values corresponding to the examples in the per-expert `Tensor`s.
+         Returns:
+             a list of `num_experts` one-dimensional `Tensor`s with type `torch.float32`
+             and shapes `[expert_batch_size_i]`
+         """
+         # split nonzero gates for each expert
+         return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
+
+
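The docstring's example use, fleshed out into a runnable sketch (illustrative only: made-up sizes, plain nn.Linear experts, and one-hot gates standing in for the real gating network):

import torch
import torch.nn as nn

num_experts, batch, d_in, d_out = 4, 8, 16, 3
experts = nn.ModuleList([nn.Linear(d_in, d_out) for _ in range(num_experts)])

# one-hot gates: each batch element is routed to a single random expert
gates = torch.zeros(batch, num_experts)
gates[torch.arange(batch), torch.randint(num_experts, (batch,))] = 1.0

inputs = torch.randn(batch, d_in)
dispatcher = SparseDispatcher(num_experts, gates)
expert_inputs = dispatcher.dispatch(inputs)
expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
outputs = dispatcher.combine(expert_outputs)
print(outputs.shape)   # torch.Size([8, 3])
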
+ class Linear_extractor_cluster(nn.Module):
+     """Sparsely gated mixture-of-experts layer with Linear_extractor modules as experts.
+     Args:
+         noisy_gating: a boolean - whether to add noise to the gating logits at train time
+         num_experts: an integer - number of experts
+         seq_len: an integer - length of the input window
+         k: an integer - how many experts to use for each batch element
+         d_model: an integer - output length of each expert (see Linear_extractor)
+         enc_in: an integer - number of input channels
+         CI: a boolean - channel-independent processing
+         moving_avg: an integer - kernel size of the series decomposition
+         hidden_size: an integer - hidden size of the gating networks
+     """
+
+     def __init__(self, noisy_gating, num_experts, seq_len, k, d_model, enc_in, CI, moving_avg, hidden_size):
+         super(Linear_extractor_cluster, self).__init__()
+         self.noisy_gating = noisy_gating
+         self.num_experts = num_experts
+         self.input_size = seq_len
+         self.k = k
+         # instantiate experts
+         self.experts = nn.ModuleList([Linear_extractor(seq_len, d_model, enc_in, CI, moving_avg) for _ in range(self.num_experts)])
+         self.W_h = nn.Parameter(torch.eye(self.num_experts))
+         self.gate = encoder(seq_len, num_experts, hidden_size)
+         self.noise = encoder(seq_len, num_experts, hidden_size)
+
+         self.n_vars = enc_in
+         self.revin = RevIN(self.n_vars)
+
+         self.CI = CI
+         self.softplus = nn.Softplus()
+         self.softmax = nn.Softmax(1)
+         self.register_buffer("mean", torch.tensor([0.0]))
+         self.register_buffer("std", torch.tensor([1.0]))
+         assert self.k <= self.num_experts
+
+     def cv_squared(self, x):
+         """The squared coefficient of variation of a sample.
+         Useful as a loss to encourage a positive distribution to be more uniform.
+         Epsilons added for numerical stability.
+         Returns 0 for an empty Tensor.
+         Args:
+             x: a `Tensor`.
+         Returns:
+             a `Scalar`.
+         """
+         eps = 1e-10
+         # degenerate case: num_experts = 1
+         if x.shape[0] == 1:
+             return torch.tensor([0], device=x.device, dtype=x.dtype)
+         return x.float().var() / (x.float().mean() ** 2 + eps)
+
+     def _gates_to_load(self, gates):
+         """Compute the true load per expert, given the gates.
+         The load is the number of examples for which the corresponding gate is > 0.
+         Args:
+             gates: a `Tensor` of shape [batch_size, n]
+         Returns:
+             a float32 `Tensor` of shape [n]
+         """
+         return (gates > 0).sum(0)
+
+     def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
+         """Helper function to NoisyTopKGating.
+         Computes the probability that a value is in the top k, given different random noise.
+         This gives us a way of backpropagating from a loss that balances the number
+         of times each expert is in the top k experts per example.
+         In the case of no noise, pass in None for noise_stddev, and the result will
+         not be differentiable.
+         Args:
+             clean_values: a `Tensor` of shape [batch, n].
+             noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
+                 normally distributed noise with standard deviation noise_stddev.
+             noise_stddev: a `Tensor` of shape [batch, n], or None
+             noisy_top_values: a `Tensor` of shape [batch, m].
+                 "values" output of torch.topk(noisy_top_values, m). m >= k+1
+         Returns:
+             a `Tensor` of shape [batch, n].
+         """
+         batch = clean_values.size(0)
+         m = noisy_top_values.size(1)
+         top_values_flat = noisy_top_values.flatten()
+
+         threshold_positions_if_in = (
+             torch.arange(batch, device=clean_values.device) * m + self.k
+         )
+         threshold_if_in = torch.unsqueeze(
+             torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
+         )
+         # is each value currently in the top k
+         is_in = torch.gt(noisy_values, threshold_if_in)
+         threshold_positions_if_out = threshold_positions_if_in - 1
+         threshold_if_out = torch.unsqueeze(
+             torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
+         )
+         normal = Normal(self.mean, self.std)
+         prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
+         prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
+         prob = torch.where(is_in, prob_if_in, prob_if_out)
+         return prob
+
+     def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
+         """Noisy top-k gating.
+         See paper: https://arxiv.org/abs/1701.06538.
+         Args:
+             x: input Tensor with shape [batch_size, input_size]
+             train: a boolean - we only add noise at training time.
+             noise_epsilon: a float
+         Returns:
+             gates: a Tensor with shape [batch_size, num_experts]
+             load: a Tensor with shape [num_experts]
+         """
+         clean_logits = self.gate(x)
+
+         if self.noisy_gating and train:
+             raw_noise_stddev = self.noise(x)
+             noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
+             noise = torch.randn_like(clean_logits)
+             noisy_logits = clean_logits + (noise * noise_stddev)
+             logits = noisy_logits @ self.W_h
+         else:
+             logits = clean_logits
+
+         # take the top k + 1 gates; the extra one is needed for the noisy-gating load estimate
+         logits = self.softmax(logits)
+         top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
+         top_k_logits = top_logits[:, : self.k]
+         top_k_indices = top_indices[:, : self.k]
+         top_k_gates = top_k_logits / (
+             top_k_logits.sum(1, keepdim=True) + 1e-6
+         )  # normalization
+
+         zeros = torch.zeros_like(logits, requires_grad=True)
+         gates = zeros.scatter(1, top_k_indices, top_k_gates)
+
+         if self.noisy_gating and self.k < self.num_experts and train:
+             load = (
+                 self._prob_in_top_k(
+                     clean_logits, noisy_logits, noise_stddev, top_logits
+                 )
+             ).sum(0)
+         else:
+             load = self._gates_to_load(gates)
+         return gates, load
+
+     def forward(self, x, loss_coef=1):
+         """Args:
+             x: tensor shape [batch_size, input_size]
+             loss_coef: a scalar - multiplier on load-balancing losses
+
+         Returns:
+             y: a tensor with shape [batch_size, output_size].
+             extra_training_loss: a scalar. This should be added into the overall
+                 training loss of the model. The backpropagation of this loss
+                 encourages all experts to be approximately equally used across a batch.
+         """
+         gates, load = self.noisy_top_k_gating(x, self.training)
+         # calculate importance loss
+         importance = gates.sum(0)
+         loss = self.cv_squared(importance) + self.cv_squared(load)
+         loss *= loss_coef
+
+         dispatcher = SparseDispatcher(self.num_experts, gates)
+         if self.CI:
+             x_norm = rearrange(x, "(x y) l c -> x l (y c)", y=self.n_vars)
+             x_norm = self.revin(x_norm, "norm")
+             x_norm = rearrange(x_norm, "x l (y c) -> (x y) l c", y=self.n_vars)
+         else:
+             x_norm = self.revin(x, "norm")
+
+         expert_inputs = dispatcher.dispatch(x_norm)
+
+         gates = dispatcher.expert_to_gates()
+         expert_outputs = [
+             self.experts[i](expert_inputs[i]) for i in range(self.num_experts)
+         ]
+         y = dispatcher.combine(expert_outputs)
+
+         return y, loss
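An end-to-end sketch of the mixture layer (illustrative hyperparameters only: k = 2 of 4 experts, a 96-step window over 7 channels). The auxiliary loss is meant to be added to the training objective:

import torch

cluster = Linear_extractor_cluster(
    noisy_gating=True, num_experts=4, seq_len=96, k=2,
    d_model=128, enc_in=7, CI=False, moving_avg=25, hidden_size=256,
)
x = torch.randn(32, 96, 7)    # [batch, seq_len, channels]
y, aux_loss = cluster(x)      # y: [32, 128, 7]
loss = y.mean() + aux_loss    # aux_loss encourages balanced expert usage
loss.backward()
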
dsipts/models/duet/masked.py
@@ -0,0 +1,202 @@
+ import torch.nn as nn
+ import torch
+ from math import sqrt
+
+ import torch.nn.functional as F
+ from torch.nn.functional import gumbel_softmax
+ import math
+ import torch.fft
+ from einops import rearrange
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
+         super(EncoderLayer, self).__init__()
+         d_ff = d_ff or 4 * d_model
+         self.attention = attention
+         self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+         self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+         self.activation = F.relu if activation == "relu" else F.gelu
+
+     def forward(self, x, attn_mask=None, tau=None, delta=None):
+         new_x, attn = self.attention(
+             x, x, x,
+             attn_mask=attn_mask,
+             tau=tau, delta=delta
+         )
+         x = x + self.dropout(new_x)
+
+         # position-wise feed-forward, implemented as two 1x1 convolutions
+         y = x = self.norm1(x)
+         y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+         y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+         return self.norm2(x + y), attn
+
+
+ class Encoder(nn.Module):
+     def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
+         super(Encoder, self).__init__()
+         self.attn_layers = nn.ModuleList(attn_layers)
+         self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
+         self.norm = norm_layer
+
+     def forward(self, x, attn_mask=None, tau=None, delta=None):
+         # x [B, L, D]
+         attns = []
+         if self.conv_layers is not None:
+             for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
+                 delta = delta if i == 0 else None
+                 x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
+                 x = conv_layer(x)
+                 attns.append(attn)
+             x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
+             attns.append(attn)
+         else:
+             for attn_layer in self.attn_layers:
+                 x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
+                 attns.append(attn)
+
+         if self.norm is not None:
+             x = self.norm(x)
+
+         return x, attns
+
+
+ class FullAttention(nn.Module):
+     def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
+         super(FullAttention, self).__init__()
+         self.scale = scale
+         self.mask_flag = mask_flag
+         self.output_attention = output_attention
+         self.dropout = nn.Dropout(attention_dropout)
+
+     def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+         B, L, H, E = queries.shape
+         _, S, _, D = values.shape
+         scale = self.scale or 1. / sqrt(E)
+
+         scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+
+         if self.mask_flag:
+             # additive mask: entries where attn_mask == 0 receive a large
+             # negative score, so their post-softmax weight is effectively zero
+             large_negative = -math.log(1e10)
+             attention_mask = torch.where(attn_mask == 0, large_negative, 0)
+
+             scores = scores * attn_mask + attention_mask
+
+         A = self.dropout(torch.softmax(scale * scores, dim=-1))
+         V = torch.einsum("bhls,bshd->blhd", A, values)
+
+         if self.output_attention:
+             return V.contiguous(), A
+         else:
+             return V.contiguous(), None
+
+
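A tiny numeric sketch of the masking trick above (illustrative only). Multiplying by the binary mask zeroes the raw scores, and the additive large-negative term drives the masked softmax weights toward zero:

import math
import torch

scores = torch.tensor([[1.0, 2.0, 3.0]])
attn_mask = torch.tensor([[1.0, 0.0, 1.0]])     # middle position masked out
large_negative = -math.log(1e10)
additive = torch.where(attn_mask == 0, torch.tensor(large_negative), torch.tensor(0.0))
masked = scores * attn_mask + additive
print(torch.softmax(masked, dim=-1))            # middle weight is ~0
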
+ class AttentionLayer(nn.Module):
+     def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
+         super(AttentionLayer, self).__init__()
+
+         d_keys = d_keys or (d_model // n_heads)
+         d_values = d_values or (d_model // n_heads)
+
+         self.inner_attention = attention
+         self.query_projection = nn.Linear(d_model, d_keys * n_heads)
+         self.key_projection = nn.Linear(d_model, d_keys * n_heads)
+         self.value_projection = nn.Linear(d_model, d_values * n_heads)
+         self.out_projection = nn.Linear(d_values * n_heads, d_model)
+         self.n_heads = n_heads
+
+     def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+         B, L, _ = queries.shape
+         _, S, _ = keys.shape
+         H = self.n_heads
+
+         queries = self.query_projection(queries).view(B, L, H, -1)
+         keys = self.key_projection(keys).view(B, S, H, -1)
+         values = self.value_projection(values).view(B, S, H, -1)
+
+         out, attn = self.inner_attention(
+             queries,
+             keys,
+             values,
+             attn_mask,
+             tau=tau,
+             delta=delta
+         )
+         out = out.view(B, L, -1)
+
+         return self.out_projection(out), attn
+
+
+ class Mahalanobis_mask(nn.Module):
+     def __init__(self, input_size):
+         super(Mahalanobis_mask, self).__init__()
+         frequency_size = input_size // 2 + 1
+         self.A = nn.Parameter(torch.randn(frequency_size, frequency_size), requires_grad=True)
+
+     def calculate_prob_distance(self, X):
+         # learned Mahalanobis-style distance between channel spectra
+         XF = torch.abs(torch.fft.rfft(X, dim=-1))
+         X1 = XF.unsqueeze(2)
+         X2 = XF.unsqueeze(1)
+
+         # B x C x C x D
+         diff = X1 - X2
+
+         temp = torch.einsum("dk,bxck->bxcd", self.A, diff)
+         dist = torch.einsum("bxcd,bxcd->bxc", temp, temp)
+
+         # exp_dist = torch.exp(-dist)
+         exp_dist = 1 / (dist + 1e-10)
+         # zero out the diagonal
+         identity_matrices = 1 - torch.eye(exp_dist.shape[-1])
+         mask = identity_matrices.repeat(exp_dist.shape[0], 1, 1).to(exp_dist.device)
+         exp_dist = torch.einsum("bxc,bxc->bxc", exp_dist, mask)
+         exp_max, _ = torch.max(exp_dist, dim=-1, keepdim=True)
+         exp_max = exp_max.detach()
+
+         # B x C x C
+         p = exp_dist / exp_max
+
+         identity_matrices = torch.eye(p.shape[-1])
+         p1 = torch.einsum("bxc,bxc->bxc", p, mask)
+
+         diag = identity_matrices.repeat(p.shape[0], 1, 1).to(p.device)
+         p = (p1 + diag) * 0.99
+
+         return p
+
+     def bernoulli_gumbel_rsample(self, distribution_matrix):
+         b, c, d = distribution_matrix.shape
+
+         flatten_matrix = rearrange(distribution_matrix, 'b c d -> (b c d) 1')
+         r_flatten_matrix = 1 - flatten_matrix
+
+         log_flatten_matrix = torch.log(flatten_matrix / r_flatten_matrix)
+         log_r_flatten_matrix = torch.log(r_flatten_matrix / flatten_matrix)
+
+         new_matrix = torch.concat([log_flatten_matrix, log_r_flatten_matrix], dim=-1)
+         resample_matrix = gumbel_softmax(new_matrix, hard=True)
+
+         resample_matrix = rearrange(resample_matrix[..., 0], '(b c d) -> b c d', b=b, c=c, d=d)
+         return resample_matrix
+
+     def forward(self, X):
+         p = self.calculate_prob_distance(X)
+
+         # p[b, i, j]: probability, under the Bernoulli, that channels i and j are related
+         sample = self.bernoulli_gumbel_rsample(p)
+
+         mask = sample.unsqueeze(1)
+         return mask
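A minimal sketch tying the pieces together (illustrative sizes; the [batch, channels, length] layout is inferred from the rfft over the last dimension). The resulting binary [B, 1, C, C] tensor is the channel-relation `attn_mask` consumed by FullAttention above:

import torch

masker = Mahalanobis_mask(input_size=96)   # lookback window of 96 steps
x = torch.randn(32, 7, 96)                 # [batch, channels, length]
mask = masker(x)
print(mask.shape)                          # torch.Size([32, 1, 7, 7])
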