dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of dsipts has been flagged as potentially problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/utils.py ADDED
@@ -0,0 +1,624 @@
+ import torch
+ import torch.nn.init as init
+ from torch import nn
+ import numpy as np
+ from numba import jit
+ from torch.autograd import Function
+
+
+ def get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss):
+     message = f'Can {"NOT" if not handle_multivariate else ""} handle multivariate output\n' \
+               f'Can {"NOT" if not handle_future_covariates else ""} handle future covariates\n' \
+               f'Can {"NOT" if not handle_categorical_variables else ""} handle categorical covariates\n' \
+               f'Can {"NOT" if not handle_quantile_loss else ""} handle Quantile loss function'
+
+     return message
+
+
+ class SinkhornDistance():
+     r"""
+     Given two empirical measures each with :math:`P_1` locations
+     :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
+     outputs an approximation of the regularized OT cost for point clouds.
+
+     Args:
+         eps (float): regularization coefficient
+         max_iter (int): maximum number of Sinkhorn iterations
+         reduction (string, optional): Specifies the reduction to apply to the output:
+             'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
+             'mean': the sum of the output will be divided by the number of
+             elements in the output, 'sum': the output will be summed. Default: 'none'
+
+     Shape:
+         - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
+         - Output: :math:`(N)` or :math:`()`, depending on `reduction`
+     """
+     def __init__(self, eps, max_iter, reduction='none'):
+         super().__init__()
+         self.eps = eps
+         self.max_iter = max_iter
+         self.reduction = reduction
+
+     def compute(self, x, y):
+         # The Sinkhorn algorithm takes as input three variables:
+         C = self._cost_matrix(x, y).to(x.device)  # Wasserstein cost function
+         x_points = x.shape[-2]
+         y_points = y.shape[-2]
+         if x.dim() == 2:
+             batch_size = 1
+         else:
+             batch_size = x.shape[0]
+
+         # both marginals are fixed with equal weights
+         mu = torch.empty(batch_size, x_points, dtype=torch.float,
+                          requires_grad=False).fill_(1.0 / x_points).squeeze().to(x.device)
+         nu = torch.empty(batch_size, y_points, dtype=torch.float,
+                          requires_grad=False).fill_(1.0 / y_points).squeeze().to(x.device)
+
+         u = torch.zeros_like(mu).to(x.device)
+         v = torch.zeros_like(nu).to(x.device)
+         # To check if the algorithm terminates because of the threshold
+         # or because max iterations were reached
+         actual_nits = 0
+         # Stopping criterion
+         thresh = 1e-1
+
+         # Sinkhorn iterations
+         for i in range(self.max_iter):
+             u1 = u  # useful to check the update
+             u = self.eps * (torch.log(mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
+             v = self.eps * (torch.log(nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
+             err = (u - u1).abs().sum(-1).mean()
+
+             actual_nits += 1
+             if err.item() < thresh:
+                 break
+
+         U, V = u, v
+         # Transport plan pi = diag(a)*K*diag(b)
+         pi = torch.exp(self.M(C, U, V))
+         # Sinkhorn distance
+         cost = torch.sum(pi * C, dim=(-2, -1))
+
+         if self.reduction == 'mean':
+             cost = cost.mean()
+         elif self.reduction == 'sum':
+             cost = cost.sum()
+
+         return cost  # , pi, C
+
+     def M(self, C, u, v):
+         r"""Modified cost for logarithmic updates: $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$."""
+         return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps
+
+     @staticmethod
+     def _cost_matrix(x, y, p=2):
+         "Returns the matrix of $|x_i-y_j|^p$."
+         x_col = x.unsqueeze(-2)
+         y_lin = y.unsqueeze(-3)
+         C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
+         return C
+
+     @staticmethod
+     def ave(u, u1, tau):
+         "Barycenter subroutine, used by kinetic acceleration through extrapolation."
+         return tau * u + (1 - tau) * u1
+
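For orientation (not part of the package), here is a minimal sketch of how SinkhornDistance might be exercised; the tensor shapes follow the Shape note in the docstring, and the sizes and hyperparameters below are made up for illustration.

# Illustrative only: approximate OT cost between two batches of 2-D point clouds.
import torch

sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='mean')
x = torch.rand(4, 50, 2)          # (N=4 batches, P1=50 points, D=2)
y = torch.rand(4, 60, 2)          # (N=4 batches, P2=60 points, D=2)
cost = sinkhorn.compute(x, y)     # scalar, because reduction='mean'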
+ class QuantileLossMO(nn.Module):
+     """Multi-output quantile (pinball) loss, adapted from a public implementation.
+     Expects predictions of shape B x L x C x Q (batch, length, channels, quantiles).
+     """
+     def __init__(self, quantiles):
+         super().__init__()
+         self.quantiles = quantiles
+
+     def forward(self, preds, target):
+         assert not target.requires_grad
+         assert preds.size(0) == target.size(0)
+         tot_loss = 0
+         for j in range(preds.shape[2]):
+             losses = []
+             ## suppose BxLxCxMUL
+             for i, q in enumerate(self.quantiles):
+                 errors = target[:, :, j] - preds[:, :, j, i]
+                 losses.append(torch.abs(torch.max((q - 1) * errors, q * errors)))
+             loss = torch.mean(torch.sum(torch.cat(losses, dim=1), dim=1))
+             tot_loss += loss
+         return tot_loss / preds.shape[2] / len(self.quantiles)
+
+
+ class L1Loss(nn.Module):
+     """Custom L1 loss computed on the first element of the last (quantile) dimension."""
+     def __init__(self):
+         super().__init__()
+         self.f = nn.L1Loss()
+
+     def forward(self, preds, target):
+         return self.f(preds[:, :, :, 0], target)
+
+
+ class Permute(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, input):
+         return torch.permute(input, (0, 2, 1))
+
+
+ def get_activation(activation):
+     # resolve an activation given by name, e.g. 'nn.ReLU'; eval assumes the
+     # string comes from trusted configuration
+     return eval(activation)
+
+ def weight_init_zeros(m):
+     """Set the weights of supported layer types to constant values (mostly zeros)."""
+     if isinstance(m, nn.LSTM):
+         for param in m.parameters():
+             init.constant_(param.data, 0.0)
+     elif isinstance(m, nn.Embedding):
+         init.constant_(m.weight, 0.0)
+
+     elif isinstance(m, nn.LayerNorm):
+         init.zeros_(m.bias)
+         init.ones_(m.weight)
+
+     elif isinstance(m, nn.LSTMCell):
+         for param in m.parameters():
+             init.constant_(param.data, 0.0)
+     elif isinstance(m, nn.GRU):
+         for param in m.parameters():
+             init.constant_(param.data, 0.0)
+         for names in m._all_weights:
+             for name in filter(lambda n: "bias" in n, names):
+                 bias = getattr(m, name)
+                 n = bias.size(0)
+                 bias.data[:n // 3].fill_(-1.)
+     elif isinstance(m, nn.GRUCell):
+         for param in m.parameters():
+             init.constant_(param.data, 0.0)
+
+     else:
+         try:
+             init.constant_(m.weight.data, 0.0)
+             if m.bias is not None:
+                 init.constant_(m.bias.data, 0.0)
+         except Exception:
+             # modules without weight/bias attributes are skipped
+             pass
+
+ def weight_init(m):
+     """
+     Usage:
+         model = Model()
+         model.apply(weight_init)
+     """
+     if isinstance(m, nn.Conv1d):
+         init.normal_(m.weight.data)
+         if m.bias is not None:
+             init.normal_(m.bias.data)
+     elif isinstance(m, nn.Conv2d):
+         init.xavier_normal_(m.weight.data)
+         if m.bias is not None:
+             init.normal_(m.bias.data)
+     elif isinstance(m, nn.Conv3d):
+         init.xavier_normal_(m.weight.data)
+         if m.bias is not None:
+             init.normal_(m.bias.data)
+     elif isinstance(m, nn.ConvTranspose1d):
+         init.normal_(m.weight.data)
+         if m.bias is not None:
+             init.normal_(m.bias.data)
+     elif isinstance(m, nn.ConvTranspose2d):
+         init.xavier_normal_(m.weight.data)
+         if m.bias is not None:
+             init.normal_(m.bias.data)
+     elif isinstance(m, nn.ConvTranspose3d):
+         init.xavier_normal_(m.weight.data)
+         if m.bias is not None:
+             init.normal_(m.bias.data)
+     elif isinstance(m, nn.BatchNorm1d):
+         init.normal_(m.weight.data, mean=1, std=0.02)
+         init.constant_(m.bias.data, 0)
+     elif isinstance(m, nn.BatchNorm2d):
+         init.normal_(m.weight.data, mean=1, std=0.02)
+         init.constant_(m.bias.data, 0)
+     elif isinstance(m, nn.BatchNorm3d):
+         init.normal_(m.weight.data, mean=1, std=0.02)
+         init.constant_(m.bias.data, 0)
+     elif isinstance(m, nn.Linear):
+         init.xavier_normal_(m.weight.data)
+         if m.bias is not None:
+             init.normal_(m.bias.data)
+     elif isinstance(m, nn.LSTM):
+         for param in m.parameters():
+             if len(param.shape) >= 2:
+                 init.orthogonal_(param.data)
+             else:
+                 init.normal_(param.data)
+     elif isinstance(m, nn.LSTMCell):
+         for param in m.parameters():
+             if len(param.shape) >= 2:
+                 init.orthogonal_(param.data)
+             else:
+                 init.normal_(param.data)
+     elif isinstance(m, nn.GRU):
+         for param in m.parameters():
+             if len(param.shape) >= 2:
+                 init.orthogonal_(param.data)
+             else:
+                 init.normal_(param.data)
+         for names in m._all_weights:
+             for name in filter(lambda n: "bias" in n, names):
+                 bias = getattr(m, name)
+                 n = bias.size(0)
+                 bias.data[:n // 3].fill_(-1.)
+     elif isinstance(m, nn.GRUCell):
+         for param in m.parameters():
+             if len(param.shape) >= 2:
+                 init.orthogonal_(param.data)
+             else:
+                 init.normal_(param.data)
+
+     elif isinstance(m, nn.Embedding):
+         init.normal_(m.weight, mean=0.0, std=0.02)
+
+     elif isinstance(m, nn.LayerNorm):
+         init.zeros_(m.bias)
+         init.ones_(m.weight)
+
+     # if isinstance(module, nn.Linear):
+     #     torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+     #     if module.bias is not None:
+     #         torch.nn.init.zeros_(module.bias)
+
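For orientation (not part of the package), a minimal sketch of applying these initializers to an arbitrary model, following the Usage note in the docstring; the model below is an illustrative assumption.

# Illustrative only: recursively initialize every submodule by layer type.
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
model.apply(weight_init)          # Xavier / orthogonal / normal init per layer type
# model.apply(weight_init_zeros)  # constant init, e.g. for debugging baselines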
+ def pairwise_distances(x, y=None):
+     '''
+     Input: x is a Nxd matrix
+            y is an optional Mxd matrix
+     Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:];
+             if y is not given then 'y=x' is used,
+             i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
+     '''
+     x_norm = (x**2).sum(1).view(-1, 1)
+     if y is not None:
+         y_t = torch.transpose(y, 0, 1)
+         y_norm = (y**2).sum(1).view(1, -1)
+     else:
+         y_t = torch.transpose(x, 0, 1)
+         y_norm = x_norm.view(1, -1)
+
+     dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
+     return torch.clamp(dist, 0.0, float('inf'))
+
+ @jit(nopython=True)
+ def compute_softdtw(D, gamma):
+     N = D.shape[0]
+     M = D.shape[1]
+     R = np.zeros((N + 2, M + 2)) + 1e8
+     R[0, 0] = 0
+     for j in range(1, M + 1):
+         for i in range(1, N + 1):
+             r0 = -R[i - 1, j - 1] / gamma
+             r1 = -R[i - 1, j] / gamma
+             r2 = -R[i, j - 1] / gamma
+             rmax = max(max(r0, r1), r2)
+             rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
+             softmin = - gamma * (np.log(rsum) + rmax)
+             R[i, j] = D[i - 1, j - 1] + softmin
+     return R
+
+ @jit(nopython=True)
+ def compute_softdtw_backward(D_, R, gamma):
+     N = D_.shape[0]
+     M = D_.shape[1]
+     D = np.zeros((N + 2, M + 2))
+     E = np.zeros((N + 2, M + 2))
+     D[1:N + 1, 1:M + 1] = D_
+     E[-1, -1] = 1
+     R[:, -1] = -1e8
+     R[-1, :] = -1e8
+     R[-1, -1] = R[-2, -2]
+     for j in range(M, 0, -1):
+         for i in range(N, 0, -1):
+             a0 = (R[i + 1, j] - R[i, j] - D[i + 1, j]) / gamma
+             b0 = (R[i, j + 1] - R[i, j] - D[i, j + 1]) / gamma
+             c0 = (R[i + 1, j + 1] - R[i, j] - D[i + 1, j + 1]) / gamma
+             a = np.exp(a0)
+             b = np.exp(b0)
+             c = np.exp(c0)
+             E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c
+     return E[1:N + 1, 1:M + 1]
+
+ class SoftDTWBatch(Function):
+     @staticmethod
+     def forward(ctx, D, gamma=1.0):  # D.shape: [batch_size, N, N]
+         dev = D.device
+         batch_size, N, N = D.shape
+         gamma = torch.FloatTensor([gamma]).to(dev)
+         D_ = D.detach().cpu().numpy()
+         g_ = gamma.item()
+
+         total_loss = 0
+         R = torch.zeros((batch_size, N + 2, N + 2)).to(dev)
+         for k in range(0, batch_size):  # loop over all D in the batch
+             Rk = torch.FloatTensor(compute_softdtw(D_[k, :, :], g_)).to(dev)
+             R[k:k + 1, :, :] = Rk
+             total_loss = total_loss + Rk[-2, -2]
+         ctx.save_for_backward(D, R, gamma)
+         return total_loss / batch_size
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         dev = grad_output.device
+         D, R, gamma = ctx.saved_tensors
+         batch_size, N, N = D.shape
+         D_ = D.detach().cpu().numpy()
+         R_ = R.detach().cpu().numpy()
+         g_ = gamma.item()
+
+         E = torch.zeros((batch_size, N, N)).to(dev)
+         for k in range(batch_size):
+             Ek = torch.FloatTensor(compute_softdtw_backward(D_[k, :, :], R_[k, :, :], g_)).to(dev)
+             E[k:k + 1, :, :] = Ek
+
+         return grad_output * E, None
+
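For orientation (not part of the package), one way pairwise_distances and SoftDTWBatch might be combined into a differentiable soft-DTW style loss for univariate series; the helper name, shapes, and gamma value are illustrative assumptions.

# Illustrative only: soft-DTW style loss for batches of univariate sequences.
import torch

def soft_dtw_loss(preds, target, gamma=0.01):
    # preds, target: (batch, length) series; D[k] holds pairwise squared distances
    batch, length = preds.shape
    D = torch.zeros(batch, length, length, device=preds.device)
    for k in range(batch):
        D[k] = pairwise_distances(target[k].view(-1, 1), preds[k].view(-1, 1))
    return SoftDTWBatch.apply(D, gamma)

preds = torch.rand(8, 24, requires_grad=True)
loss = soft_dtw_loss(preds, torch.rand(8, 24))
loss.backward()                   # gradients flow back through pairwise_distances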
+ @jit(nopython=True)
+ def my_max(x, gamma):
+     # use the log-sum-exp trick
+     max_x = np.max(x)
+     exp_x = np.exp((x - max_x) / gamma)
+     Z = np.sum(exp_x)
+     return gamma * np.log(Z) + max_x, exp_x / Z
+
+ @jit(nopython=True)
+ def my_min(x, gamma):
+     min_x, argmax_x = my_max(-x, gamma)
+     return - min_x, argmax_x
+
+ @jit(nopython=True)
+ def my_max_hessian_product(p, z, gamma):
+     return (p * z - p * np.sum(p * z)) / gamma
+
+ @jit(nopython=True)
+ def my_min_hessian_product(p, z, gamma):
+     return - my_max_hessian_product(p, z, gamma)
+
+
+ @jit(nopython=True)
+ def dtw_grad(theta, gamma):
+     m = theta.shape[0]
+     n = theta.shape[1]
+     V = np.zeros((m + 1, n + 1))
+     V[:, 0] = 1e10
+     V[0, :] = 1e10
+     V[0, 0] = 0
+
+     Q = np.zeros((m + 2, n + 2, 3))
+
+     for i in range(1, m + 1):
+         for j in range(1, n + 1):
+             # theta is indexed starting from 0.
+             v, Q[i, j] = my_min(np.array([V[i, j - 1],
+                                           V[i - 1, j - 1],
+                                           V[i - 1, j]]), gamma)
+             V[i, j] = theta[i - 1, j - 1] + v
+
+     E = np.zeros((m + 2, n + 2))
+     E[m + 1, :] = 0
+     E[:, n + 1] = 0
+     E[m + 1, n + 1] = 1
+     Q[m + 1, n + 1] = 1
+
+     for i in range(m, 0, -1):
+         for j in range(n, 0, -1):
+             E[i, j] = Q[i, j + 1, 0] * E[i, j + 1] + \
+                       Q[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
+                       Q[i + 1, j, 2] * E[i + 1, j]
+
+     return V[m, n], E[1:m + 1, 1:n + 1], Q, E
+
+
+ @jit(nopython=True)
+ def dtw_hessian_prod(theta, Z, Q, E, gamma):
+     m = Z.shape[0]
+     n = Z.shape[1]
+
+     V_dot = np.zeros((m + 1, n + 1))
+     V_dot[0, 0] = 0
+
+     Q_dot = np.zeros((m + 2, n + 2, 3))
+     for i in range(1, m + 1):
+         for j in range(1, n + 1):
+             # theta is indexed starting from 0.
+             V_dot[i, j] = Z[i - 1, j - 1] + \
+                           Q[i, j, 0] * V_dot[i, j - 1] + \
+                           Q[i, j, 1] * V_dot[i - 1, j - 1] + \
+                           Q[i, j, 2] * V_dot[i - 1, j]
+
+             v = np.array([V_dot[i, j - 1], V_dot[i - 1, j - 1], V_dot[i - 1, j]])
+             Q_dot[i, j] = my_min_hessian_product(Q[i, j], v, gamma)
+     E_dot = np.zeros((m + 2, n + 2))
+
+     for j in range(n, 0, -1):
+         for i in range(m, 0, -1):
+             E_dot[i, j] = Q_dot[i, j + 1, 0] * E[i, j + 1] + \
+                           Q[i, j + 1, 0] * E_dot[i, j + 1] + \
+                           Q_dot[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
+                           Q[i + 1, j + 1, 1] * E_dot[i + 1, j + 1] + \
+                           Q_dot[i + 1, j, 2] * E[i + 1, j] + \
+                           Q[i + 1, j, 2] * E_dot[i + 1, j]
+
+     return V_dot[m, n], E_dot[1:m + 1, 1:n + 1]
+
+ class PathDTWBatch(Function):
+     @staticmethod
+     def forward(ctx, D, gamma):  # D.shape: [batch_size, N, N]
+         batch_size, N, N = D.shape
+         device = D.device
+         D_cpu = D.detach().cpu().numpy()
+         gamma_gpu = torch.FloatTensor([gamma]).to(device)
+
+         grad_gpu = torch.zeros((batch_size, N, N)).to(device)
+         Q_gpu = torch.zeros((batch_size, N + 2, N + 2, 3)).to(device)
+         E_gpu = torch.zeros((batch_size, N + 2, N + 2)).to(device)
+
+         for k in range(0, batch_size):  # loop over all D in the batch
+             _, grad_cpu_k, Q_cpu_k, E_cpu_k = dtw_grad(D_cpu[k, :, :], gamma)
+             grad_gpu[k, :, :] = torch.FloatTensor(grad_cpu_k).to(device)
+             Q_gpu[k, :, :, :] = torch.FloatTensor(Q_cpu_k).to(device)
+             E_gpu[k, :, :] = torch.FloatTensor(E_cpu_k).to(device)
+         ctx.save_for_backward(grad_gpu, D, Q_gpu, E_gpu, gamma_gpu)
+         return torch.mean(grad_gpu, dim=0)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         device = grad_output.device
+         grad_gpu, D_gpu, Q_gpu, E_gpu, gamma = ctx.saved_tensors
+         D_cpu = D_gpu.detach().cpu().numpy()
+         Q_cpu = Q_gpu.detach().cpu().numpy()
+         E_cpu = E_gpu.detach().cpu().numpy()
+         gamma = gamma.detach().cpu().numpy()[0]
+         Z = grad_output.detach().cpu().numpy()
+
+         batch_size, N, N = D_cpu.shape
+         Hessian = torch.zeros((batch_size, N, N)).to(device)
+         for k in range(0, batch_size):
+             _, hess_k = dtw_hessian_prod(D_cpu[k, :, :], Z, Q_cpu[k, :, :, :], E_cpu[k, :, :], gamma)
+             Hessian[k:k + 1, :, :] = torch.FloatTensor(hess_k).to(device)
+
+         return Hessian, None
+
+ import math
+ from typing import Union
+
+
+ class Embedding_cat_variables(nn.Module):
+     def __init__(self, length: int, d_model: int, emb_dims: list, reduction_mode: str = 'mean', use_classical_positional_encoder: bool = False, device: str = 'cpu'):
+         """
+         Embeds categorical variables with optional positional encodings.
+
+         Args:
+             length (int): Sequence length (e.g., total time steps).
+             d_model (int): Output embedding dimension.
+             emb_dims (list): Vocabulary sizes for each categorical feature.
+             reduction_mode (str): 'mean', 'sum', or 'none'.
+             use_classical_positional_encoder (bool): Whether to use sinusoidal positional encoding.
+             device (str): Device name (e.g., 'cpu' or 'cuda').
+
+         Notes:
+             - If `reduction_mode` is 'none', all embeddings are concatenated.
+             - If `use_classical_positional_encoder` is True, uses fixed sin/cos encoding.
+             - If False, treats position as a categorical variable and embeds it.
+         """
+         super().__init__()
+         self.length = length
+         self.device = device
+         self.reduction_mode = reduction_mode
+         self.emb_dims = emb_dims
+
+         self.use_classical_positional_encoder = use_classical_positional_encoder
+
+         if use_classical_positional_encoder:
+             pe = torch.zeros(length, d_model).to(device)
+             position = torch.arange(0, length, dtype=torch.float).unsqueeze(1).to(device)
+
+             # Compute the div_term (frequencies for sinusoids)
+             div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).to(device)
+             div_term_odd = torch.exp(torch.arange(0, d_model - d_model % 2, 2).float() * (-math.log(10000.0) / d_model)).to(device)
+
+             # Apply sine to even indices, cosine to odd indices
+             pe[:, 0::2] = torch.sin(position * div_term)
+             pe[:, 1::2] = torch.cos(position * div_term_odd)
+             ## this is a static positional encoder
+             self.register_buffer('pe', pe)  ## static
+
+         else:
+             self.register_buffer('pe_emb', torch.arange(0, self.length).reshape(1, -1, 1))  ## static
+             self.emb_dims = [length + 1] + emb_dims
+             # otherwise we add a new embedding layer for the position
+
+         if self.reduction_mode == 'none':
+             self.output_channels = len(self.emb_dims) * d_model
+             if use_classical_positional_encoder:
+                 self.output_channels += d_model
+         else:
+             self.output_channels = d_model  ## if you want a fixed d_model size use the mean or sum strategy
+
+         ## this is the core
+         self.cat_n_embd = nn.ModuleList([nn.Embedding(emb_dim, d_model) for emb_dim in self.emb_dims])
+
+     ## the batch size is required in case x is None (only positional encoder)
+     def forward(self, BS: int, x: Union[torch.Tensor, None]) -> torch.Tensor:
+
+         # this is the easy part
+         if x is None:
+             if self.use_classical_positional_encoder:
+                 return self.pe.repeat(BS, 1, 1)
+             else:
+                 return self.get_cat_n_embd(self.pe_emb.repeat(BS, 1, 1)).squeeze(2)
+
+         else:
+             if self.use_classical_positional_encoder is False:
+                 cat_vars = torch.cat((self.pe_emb.repeat(BS, 1, 1), x), dim=2)
+             else:
+                 cat_vars = x
+             # building the encoders
+             cat_n_embd = self.get_cat_n_embd(cat_vars)
+
+             if self.reduction_mode == 'sum':
+                 cat_n_embd = torch.sum(cat_n_embd, axis=2)
+             elif self.reduction_mode == 'mean':
+                 cat_n_embd = torch.mean(cat_n_embd, axis=2)
+             else:
+                 cat_n_embd = cat_n_embd.reshape(BS, self.length, -1)
+
+             if self.use_classical_positional_encoder:
+                 if self.reduction_mode == 'none':
+                     cat_n_embd = torch.cat([cat_n_embd, self.pe.repeat(BS, 1, 1)], 2)  ## stack the positional encoder
+                 else:
+                     cat_n_embd = cat_n_embd + self.pe.repeat(BS, 1, 1)  ## add the positional encoder
+             return cat_n_embd
+
+     ## embed each categorical column with its own embedding layer
+     def get_cat_n_embd(self, cat_vars):
+         emb = []
+         for index, layer in enumerate(self.cat_n_embd):
+             emb.append(layer(cat_vars[:, :, index]).unsqueeze(2))
+
+         cat_n_embd = torch.cat(emb, dim=2)
+         return cat_n_embd
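For orientation (not part of the package), a minimal sketch of how Embedding_cat_variables might be called; the feature choices (hour of day, day of week) and sizes are illustrative assumptions.

# Illustrative only: embed two integer-coded calendar features over a 36-step window.
import torch

emb = Embedding_cat_variables(length=36, d_model=32, emb_dims=[24, 7],
                              reduction_mode='mean', device='cpu')
cat = torch.stack([torch.randint(0, 24, (8, 36)),     # hour of day
                   torch.randint(0, 7, (8, 36))],     # day of week
                  dim=2)                               # (B=8, L=36, n_features=2)
out = emb(8, cat)                                      # (8, 36, 32) with 'mean' reduction

pos = Embedding_cat_variables(length=36, d_model=32, emb_dims=[],
                              reduction_mode='mean', device='cpu')
pos_only = pos(8, None)                                # positional embedding only, (8, 36, 32)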
dsipts/models/vva/minigpt.py ADDED
@@ -0,0 +1,83 @@
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ # -----------------------------------------------------------------------------
+
+ class NewGELU(nn.Module):
+     """
+     Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
+     Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
+     """
+     def forward(self, x):
+         return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+ class CausalSelfAttention(nn.Module):
+     """
+     A vanilla multi-head masked self-attention layer with a projection at the end.
+     It is possible to use torch.nn.MultiheadAttention here but I am including an
+     explicit implementation here to show that there is nothing too scary here.
+     """
+
+     def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, block_size):
+         super().__init__()
+         assert n_embd % n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(n_embd, 3 * n_embd)
+         # output projection
+         self.c_proj = nn.Linear(n_embd, n_embd)
+         # regularization
+         self.attn_dropout = nn.Dropout(attn_pdrop)
+         self.resid_dropout = nn.Dropout(resid_pdrop)
+         # causal mask to ensure that attention is only applied to the left in the input sequence
+         self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size))
+                                      .view(1, 1, block_size, block_size))
+         self.n_head = n_head
+         self.n_embd = n_embd
+
+     def forward(self, x):
+         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+
+         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+         att = F.softmax(att, dim=-1)
+         att = self.attn_dropout(att)
+         y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
+ class Block(nn.Module):
+     """ an unassuming Transformer block """
+
+     def __init__(self, n_embd, resid_pdrop, n_head, attn_pdrop, block_size):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(n_embd)
+         self.attn = CausalSelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop, block_size)
+         self.ln_2 = nn.LayerNorm(n_embd)
+         self.mlp = nn.ModuleDict(dict(
+             c_fc    = nn.Linear(n_embd, 4 * n_embd),
+             c_proj  = nn.Linear(4 * n_embd, n_embd),
+             act     = NewGELU(),
+             dropout = nn.Dropout(resid_pdrop),
+         ))
+         m = self.mlp
+         self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))  # MLP forward
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlpf(self.ln_2(x))
+         return x
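For orientation (not part of the package), a minimal sketch of stacking these blocks over a token embedding, in the spirit of minGPT; the vocabulary size and hyperparameters are illustrative assumptions.

# Illustrative only: run a short token sequence through two causal Transformer blocks.
import torch
import torch.nn as nn

block_size, n_embd, n_head = 16, 64, 4
blocks = nn.Sequential(
    Block(n_embd, resid_pdrop=0.1, n_head=n_head, attn_pdrop=0.1, block_size=block_size),
    Block(n_embd, resid_pdrop=0.1, n_head=n_head, attn_pdrop=0.1, block_size=block_size),
)
tok_emb = nn.Embedding(100, n_embd)             # toy vocabulary of 100 tokens
idx = torch.randint(0, 100, (2, block_size))    # (batch=2, T=16), T must be <= block_size
out = blocks(tok_emb(idx))                      # (2, 16, 64)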