dsipts 1.1.5 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/utils.py
ADDED
@@ -0,0 +1,624 @@
import torch
import torch.nn.init as init
from torch import nn
import numpy as np
from numba import jit
from torch.autograd import Function


def get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss):
    message = f'Can {"NOT" if not handle_multivariate else "" } handle multivariate output \n'\
              f'Can {"NOT" if not handle_future_covariates else "" } handle future covariates\n'\
              f'Can {"NOT" if not handle_categorical_variables else "" } handle categorical covariates\n'\
              f'Can {"NOT" if not handle_quantile_loss else "" } handle Quantile loss function'

    return message




class SinkhornDistance():
    r"""
    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.

    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'none'

    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`
    """
    def __init__(self, eps, max_iter, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def compute(self, x, y):
        # The Sinkhorn algorithm takes as input three variables :
        C = self._cost_matrix(x, y).to(x.device) # Wasserstein cost function
        x_points = x.shape[-2]
        y_points = y.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]

        # both marginals are fixed with equal weights
        mu = torch.empty(batch_size, x_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / x_points).squeeze().to(x.device)
        nu = torch.empty(batch_size, y_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / y_points).squeeze().to(x.device)

        u = torch.zeros_like(mu).to(x.device)
        v = torch.zeros_like(nu).to(x.device)
        # To check if algorithm terminates because of threshold
        # or max iterations reached
        actual_nits = 0
        # Stopping criterion
        thresh = 1e-1

        # Sinkhorn iterations
        for i in range(self.max_iter):
            u1 = u # useful to check the update
            u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()

            actual_nits += 1
            if err.item() < thresh:
                break

        U, V = u, v
        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, U, V))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost#, pi, C

    def M(self, C, u, v):
        "Modified cost for logarithmic updates"
        "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        "Returns the matrix of $|x_i-y_j|^p$."
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1

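The docstring above spells out the expected shapes. As an illustrative usage sketch (not part of the package diff; the point clouds below are made up, 50 and 60 points in R^3):

# Illustrative only: not code from the wheel. Shapes follow the SinkhornDistance docstring.
import torch

sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='mean')
x = torch.rand(8, 50, 3)   # (N, P_1, D_1): batch of 8 point clouds, 50 points each
y = torch.rand(8, 60, 3)   # (N, P_2, D_2)
cost = sinkhorn.compute(x, y)   # scalar, since reduction='mean'
print(cost.item())
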
class QuantileLossMO(nn.Module):
    """Copied from git
    """
    def __init__(self, quantiles):
        super().__init__()
        self.quantiles = quantiles

    def forward(self, preds, target):

        assert not target.requires_grad
        assert preds.size(0) == target.size(0)
        tot_loss = 0
        for j in range(preds.shape[2]):
            losses = []
            ##suppose BxLxCxMUL
            for i, q in enumerate(self.quantiles):
                errors = target[:,:,j] - preds[:,:,j, i]

                losses.append(torch.abs(torch.max((q-1) * errors,q * errors)))

            loss = torch.mean(torch.sum(torch.cat(losses, dim=1), dim=1))
            tot_loss+=loss
        return tot_loss/preds.shape[2]/len(self.quantiles)



class L1Loss(nn.Module):
    """Custom L1Loss
    """
    def __init__(self):
        super().__init__()
        self.f = nn.L1Loss()
    def forward(self, preds, target):
        return self.f(preds[:,:,:,0],target)

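The "suppose BxLxCxMUL" comment above fixes the expected layout: predictions are batch x horizon x output channels x quantiles, targets are batch x horizon x output channels. An illustrative sketch, not from the package (the sizes are assumptions):

# Illustrative only: tensor layout assumed from the BxLxCxMUL comment above.
import torch

quantiles = [0.1, 0.5, 0.9]
loss_fn = QuantileLossMO(quantiles)
preds = torch.rand(4, 24, 2, len(quantiles))   # B x L x C x Q
target = torch.rand(4, 24, 2)                  # B x L x C
loss = loss_fn(preds, target)                  # scalar pinball loss averaged over C and Q
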



class Permute(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.permute(input,(0,2,1))

def get_activation(activation):
    return eval(activation)


def weight_init_zeros(m):

    if isinstance(m, nn.LSTM):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.constant_(param.data,0.0)
            else:
                init.constant_(param.data,0.0)
    elif isinstance(m, nn.Embedding):
        init.constant_(m.weight,0.0)

    elif isinstance(m, nn.LayerNorm):
        init.zeros_(m.bias)
        init.ones_(m.weight)

    elif isinstance(m, nn.LSTMCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.constant_(param.data,0.0)
            else:
                init.constant_(param.data,0.0)
    elif isinstance(m, nn.GRU):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.constant_(param.data,0.0)
            else:
                init.constant_(param.data,0.0)
        for names in m._all_weights:
            for name in filter(lambda n: "bias" in n, names):
                bias = getattr(m, name)
                n = bias.size(0)
                bias.data[:n // 3].fill_(-1.)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.constant_(param.data,0.0)
            else:
                init.constant_(param.data,0.0)


    else:
        try:
            init.constant_(m.weight.data, 0.0)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        except:
            pass

def weight_init(m):
    """
    Usage:
        model = Model()
        model.apply(weight_init)
    """
    if isinstance(m, nn.Conv1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm1d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm3d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.LSTM):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.LSTMCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.GRU):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
        for names in m._all_weights:
            for name in filter(lambda n: "bias" in n, names):
                bias = getattr(m, name)
                n = bias.size(0)
                bias.data[:n // 3].fill_(-1.)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)

    elif isinstance(m, nn.Embedding):
        init.normal_(m.weight, mean=0.0, std=0.02)

    elif isinstance(m, nn.LayerNorm):
        init.zeros_(m.bias)
        init.ones_(m.weight)

    # if isinstance(module, nn.Linear):
    #     torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    #     if module.bias is not None:
    #         torch.nn.init.zeros_(module.bias)



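The weight_init docstring above already gives the intended call pattern (model.apply(weight_init)). A minimal sketch with a throwaway model, illustrative only and not part of the diff:

# Illustrative only: a toy model to show the apply() pattern from the docstring.
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 1))
model.apply(weight_init)          # recursively initializes every submodule by type
# model.apply(weight_init_zeros)  # alternative: zero initialization (e.g., for debugging)
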
def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
            i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)

    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return torch.clamp(dist, 0.0, float('inf'))

@jit(nopython = True)
def compute_softdtw(D, gamma):
    N = D.shape[0]
    M = D.shape[1]
    R = np.zeros((N + 2, M + 2)) + 1e8
    R[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            r0 = -R[i - 1, j - 1] / gamma
            r1 = -R[i - 1, j] / gamma
            r2 = -R[i, j - 1] / gamma
            rmax = max(max(r0, r1), r2)
            rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
            softmin = - gamma * (np.log(rsum) + rmax)
            R[i, j] = D[i - 1, j - 1] + softmin
    return R

@jit(nopython = True)
def compute_softdtw_backward(D_, R, gamma):
    N = D_.shape[0]
    M = D_.shape[1]
    D = np.zeros((N + 2, M + 2))
    E = np.zeros((N + 2, M + 2))
    D[1:N + 1, 1:M + 1] = D_
    E[-1, -1] = 1
    R[:, -1] = -1e8
    R[-1, :] = -1e8
    R[-1, -1] = R[-2, -2]
    for j in range(M, 0, -1):
        for i in range(N, 0, -1):
            a0 = (R[i + 1, j] - R[i, j] - D[i + 1, j]) / gamma
            b0 = (R[i, j + 1] - R[i, j] - D[i, j + 1]) / gamma
            c0 = (R[i + 1, j + 1] - R[i, j] - D[i + 1, j + 1]) / gamma
            a = np.exp(a0)
            b = np.exp(b0)
            c = np.exp(c0)
            E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c
    return E[1:N + 1, 1:M + 1]


class SoftDTWBatch(Function):
    @staticmethod
    def forward(ctx, D, gamma = 1.0): # D.shape: [batch_size, N , N]
        dev = D.device
        batch_size,N,N = D.shape
        gamma = torch.FloatTensor([gamma]).to(dev)
        D_ = D.detach().cpu().numpy()
        g_ = gamma.item()

        total_loss = 0
        R = torch.zeros((batch_size, N+2 ,N+2)).to(dev)
        for k in range(0, batch_size): # loop over all D in the batch
            Rk = torch.FloatTensor(compute_softdtw(D_[k,:,:], g_)).to(dev)
            R[k:k+1,:,:] = Rk
            total_loss = total_loss + Rk[-2,-2]
        ctx.save_for_backward(D, R, gamma)
        return total_loss / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        dev = grad_output.device
        D, R, gamma = ctx.saved_tensors
        batch_size,N,N = D.shape
        D_ = D.detach().cpu().numpy()
        R_ = R.detach().cpu().numpy()
        g_ = gamma.item()

        E = torch.zeros((batch_size, N ,N)).to(dev)
        for k in range(batch_size):
            Ek = torch.FloatTensor(compute_softdtw_backward(D_[k,:,:], R_[k,:,:], g_)).to(dev)
            E[k:k+1,:,:] = Ek

        return grad_output * E, None




@jit(nopython = True)
def my_max(x, gamma):
    # use the log-sum-exp trick
    max_x = np.max(x)
    exp_x = np.exp((x - max_x) / gamma)
    Z = np.sum(exp_x)
    return gamma * np.log(Z) + max_x, exp_x / Z

@jit(nopython = True)
def my_min(x,gamma) :
    min_x, argmax_x = my_max(-x, gamma)
    return - min_x, argmax_x

@jit(nopython = True)
def my_max_hessian_product(p, z, gamma):
    return ( p * z - p * np.sum(p * z) ) /gamma

@jit(nopython = True)
def my_min_hessian_product(p, z, gamma):
    return - my_max_hessian_product(p, z, gamma)


@jit(nopython = True)
def dtw_grad(theta, gamma):
    m = theta.shape[0]
    n = theta.shape[1]
    V = np.zeros((m + 1, n + 1))
    V[:, 0] = 1e10
    V[0, :] = 1e10
    V[0, 0] = 0

    Q = np.zeros((m + 2, n + 2, 3))

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            # theta is indexed starting from 0.
            v, Q[i, j] = my_min(np.array([V[i, j - 1],
                                          V[i - 1, j - 1],
                                          V[i - 1, j]]) , gamma)
            V[i, j] = theta[i - 1, j - 1] + v

    E = np.zeros((m + 2, n + 2))
    E[m + 1, :] = 0
    E[:, n + 1] = 0
    E[m + 1, n + 1] = 1
    Q[m + 1, n + 1] = 1

    for i in range(m,0,-1):
        for j in range(n,0,-1):
            E[i, j] = Q[i, j + 1, 0] * E[i, j + 1] + \
                      Q[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
                      Q[i + 1, j, 2] * E[i + 1, j]

    return V[m, n], E[1:m + 1, 1:n + 1], Q, E


@jit(nopython = True)
def dtw_hessian_prod(theta, Z, Q, E, gamma):
    m = Z.shape[0]
    n = Z.shape[1]

    V_dot = np.zeros((m + 1, n + 1))
    V_dot[0, 0] = 0

    Q_dot = np.zeros((m + 2, n + 2, 3))
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            # theta is indexed starting from 0.
            V_dot[i, j] = Z[i - 1, j - 1] + \
                          Q[i, j, 0] * V_dot[i, j - 1] + \
                          Q[i, j, 1] * V_dot[i - 1, j - 1] + \
                          Q[i, j, 2] * V_dot[i - 1, j]

            v = np.array([V_dot[i, j - 1], V_dot[i - 1, j - 1], V_dot[i - 1, j]])
            Q_dot[i, j] = my_min_hessian_product(Q[i, j], v, gamma)
    E_dot = np.zeros((m + 2, n + 2))

    for j in range(n,0,-1):
        for i in range(m,0,-1):
            E_dot[i, j] = Q_dot[i, j + 1, 0] * E[i, j + 1] + \
                          Q[i, j + 1, 0] * E_dot[i, j + 1] + \
                          Q_dot[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
                          Q[i + 1, j + 1, 1] * E_dot[i + 1, j + 1] + \
                          Q_dot[i + 1, j, 2] * E[i + 1, j] + \
                          Q[i + 1, j, 2] * E_dot[i + 1, j]

    return V_dot[m, n], E_dot[1:m + 1, 1:n + 1]


class PathDTWBatch(Function):
    @staticmethod
    def forward(ctx, D, gamma): # D.shape: [batch_size, N , N]
        batch_size,N,N = D.shape
        device = D.device
        D_cpu = D.detach().cpu().numpy()
        gamma_gpu = torch.FloatTensor([gamma]).to(device)

        grad_gpu = torch.zeros((batch_size, N ,N)).to(device)
        Q_gpu = torch.zeros((batch_size, N+2 ,N+2,3)).to(device)
        E_gpu = torch.zeros((batch_size, N+2 ,N+2)).to(device)

        for k in range(0,batch_size): # loop over all D in the batch
            _, grad_cpu_k, Q_cpu_k, E_cpu_k = dtw_grad(D_cpu[k,:,:], gamma)
            grad_gpu[k,:,:] = torch.FloatTensor(grad_cpu_k).to(device)
            Q_gpu[k,:,:,:] = torch.FloatTensor(Q_cpu_k).to(device)
            E_gpu[k,:,:] = torch.FloatTensor(E_cpu_k).to(device)
        ctx.save_for_backward(grad_gpu,D, Q_gpu ,E_gpu, gamma_gpu)
        return torch.mean(grad_gpu, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        device = grad_output.device
        grad_gpu, D_gpu, Q_gpu, E_gpu, gamma = ctx.saved_tensors
        D_cpu = D_gpu.detach().cpu().numpy()
        Q_cpu = Q_gpu.detach().cpu().numpy()
        E_cpu = E_gpu.detach().cpu().numpy()
        gamma = gamma.detach().cpu().numpy()[0]
        Z = grad_output.detach().cpu().numpy()

        batch_size,N,N = D_cpu.shape
        Hessian = torch.zeros((batch_size, N ,N)).to(device)
        for k in range(0,batch_size):
            _, hess_k = dtw_hessian_prod(D_cpu[k,:,:], Z, Q_cpu[k,:,:,:], E_cpu[k,:,:], gamma)
            Hessian[k:k+1,:,:] = torch.FloatTensor(hess_k).to(device)

        return Hessian, None



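These pieces appear to follow the usual soft-DTW / DILATE loss construction: pairwise_distances builds a per-sample cost matrix, SoftDTWBatch gives the differentiable DTW value, and PathDTWBatch gives the soft alignment path. The wiring below is an illustrative sketch under that assumption (shapes and gamma are made up), not code from the package:

# Illustrative only: combining pairwise_distances with the custom Functions above.
import torch

gamma = 0.01
pred = torch.rand(4, 20, 1, requires_grad=True)   # batch of predicted sequences
true = torch.rand(4, 20, 1)                        # batch of target sequences

D = torch.zeros(4, 20, 20)
for k in range(4):
    D[k] = pairwise_distances(true[k], pred[k])    # squared-distance matrix per sample

loss_shape = SoftDTWBatch.apply(D, gamma)          # differentiable soft-DTW term (scalar)
path = PathDTWBatch.apply(D, gamma)                # soft alignment path, N x N
loss_shape.backward()                              # gradients flow back to pred
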
import math
from typing import Union
class Embedding_cat_variables(nn.Module):
    def __init__(self, length: int, d_model: int, emb_dims: list,reduction_mode:str='mean',use_classical_positional_encoder:bool=False, device:str='cpu'):
        """
        Embeds categorical variables with optional positional encodings.

        Args:
            length (int): Sequence length (e.g., total time steps).
            d_model (int): Output embedding dimension.
            emb_dims (list): Vocabulary sizes for each categorical feature.
            reduction_mode (str): 'mean', 'sum', or 'none'.
            use_classical_positional_encoder (bool): Whether to use sinusoidal positional encoding.
            device (str): Device name (e.g., 'cpu' or 'cuda').

        Notes:
            - If `reduction_mode` is 'none', all embeddings are concatenated.
            - If `use_classical_positional_encoder` is True, uses fixed sin/cos encoding.
            - If False, treats position as a categorical variable and embeds it.
        """


        super().__init__()
        self.length = length
        self.device = device
        self.reduction_mode = reduction_mode
        self.emb_dims = emb_dims

        self.use_classical_positional_encoder = use_classical_positional_encoder


        if use_classical_positional_encoder:
            pe = torch.zeros(length, d_model).to(device)
            position = torch.arange(0, length, dtype=torch.float).unsqueeze(1).to(device)

            # Compute the div_term (frequencies for sinusoids)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).to(device)
            div_term_odd = torch.exp(torch.arange(0, d_model-d_model%2, 2).float() * (-math.log(10000.0) / d_model)).to(device)

            # Apply sine to even indices, cosine to odd indices

            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term_odd)
            ## this is static positional encoder
            self.register_buffer('pe', pe)##static


        else:
            self.register_buffer('pe_emb', torch.arange(0, self.length).reshape(1, -1, 1)) ##static
            self.emb_dims = [length+1] + emb_dims
            #otherwise we add a new embedding layer

        if self.reduction_mode =='none':
            self.output_channels = len(self.emb_dims)*d_model
            if use_classical_positional_encoder:
                self.output_channels+=d_model
        else:
            self.output_channels = d_model ## if you want to have a fixed d_model size use mean or sum strategy

        ##this is the core
        self.cat_n_embd = nn.ModuleList([nn.Embedding(emb_dim, d_model) for emb_dim in self.emb_dims])

    ##the batch size is required in case x is None (only positional encoder)
    def forward(self,BS:int, x: Union[torch.Tensor,None]) -> torch.Tensor:

        #this is the easy part
        if x is None:
            if self.use_classical_positional_encoder:
                return self.pe.repeat(BS,1,1)
            else:
                return self.get_cat_n_embd(self.pe_emb.repeat(BS,1,1)).squeeze(2)


        else:
            if self.use_classical_positional_encoder is False:
                cat_vars = torch.cat(( self.pe_emb.repeat(BS,1,1),x), dim=2)
            else:
                cat_vars = x
            #building the encoders
            cat_n_embd = self.get_cat_n_embd(cat_vars)

            if self.reduction_mode =='sum':
                cat_n_embd = torch.sum(cat_n_embd,axis=2)
            elif self.reduction_mode =='mean':
                cat_n_embd = torch.mean(cat_n_embd,axis=2)
            else:
                cat_n_embd = cat_n_embd.reshape(BS, self.length,-1)

            if self.use_classical_positional_encoder:
                if self.reduction_mode =='none':
                    cat_n_embd = torch.cat([cat_n_embd,self.pe.repeat(BS,1,1)], 2) ##stack the positional encoder
                else:
                    cat_n_embd = cat_n_embd+self.pe.repeat(BS,1,1) ##add the positional encoder
            return cat_n_embd


    ##compute the target
    def get_cat_n_embd(self, cat_vars):
        emb = []
        for index, layer in enumerate(self.cat_n_embd):
            emb.append(layer(cat_vars[:, :, index]).unsqueeze(2))

        cat_n_embd = torch.cat(emb,dim=2)
        return cat_n_embd
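To make the constructor arguments concrete, an illustrative sketch (not from the package; the feature vocabularies and window length below are assumptions based on the docstring):

# Illustrative only: two categorical features (7 and 12 levels) over a 96-step window,
# reduced by mean to one d_model-sized embedding per time step.
import torch

emb = Embedding_cat_variables(length=96, d_model=32, emb_dims=[7, 12], reduction_mode='mean')
x_cat = torch.stack([torch.randint(0, 7, (8, 96)), torch.randint(0, 12, (8, 96))], dim=2)  # B x L x num_cat
out = emb(8, x_cat)    # B x L x d_model

pos = Embedding_cat_variables(length=96, d_model=32, emb_dims=[], reduction_mode='mean')
pos_out = pos(8, None)  # positional part only, also B x L x d_model
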
dsipts/models/vva/minigpt.py
ADDED
@@ -0,0 +1,83 @@

import math

import torch
import torch.nn as nn
from torch.nn import functional as F


# -----------------------------------------------------------------------------

class NewGELU(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, n_embd,n_head,attn_pdrop,resid_pdrop,block_size):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size))
                             .view(1, 1, block_size, block_size))
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k ,v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, n_embd,resid_pdrop,n_head,attn_pdrop,block_size):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd,n_head,attn_pdrop,resid_pdrop,block_size)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc = nn.Linear(n_embd, 4 * n_embd),
            c_proj = nn.Linear(4 * n_embd, n_embd),
            act = NewGELU(),
            dropout = nn.Dropout(resid_pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x