SWoTTeD 1.0.2a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
swotted/__init__.py ADDED
@@ -0,0 +1,17 @@
+ """
+ SWoTTeD Module
+
+ @author: Hana Sebia and Thomas Guyet
+ @date: 2023
+ @institution: Inria
+ """
+
+ from swotted.swotted import swottedModule, swottedTrainer
+ from swotted.fastswotted import (
+     fastSWoTTeDDataset,
+     fastSWoTTeDModule,
+     fastSWoTTeDTrainer,
+ )
+
+
+ name = "swotted"
@@ -0,0 +1,87 @@
+ """
+ SWoTTeD Module: decomposition constraints
+ """
+
+ from functools import reduce
+ import torch
+
+
+ def sparsity_constraint(var):
+     """Sparsity constraint (L1 norm) for a tensor `var`. The lower the better.
+
+     Args:
+         var (torch.Tensor): a tensor
+
+     Returns:
+         float: constraint value
+     """
+     return torch.norm(var, 1)
+
+
+ def nonnegative_projection(*var):
+     """Transform a tensor by replacing the negative values by zeros.
+
+     Inplace transformation of the `var` parameter.
+
+     Args:
+         var: a tensor or a collection of tensors
+     """
+     for X in var:
+         X.data[X.data < 0] = 0
+
+
+ def normalization_constraint(*var):
+     """Transform a tensor by replacing the negative values by zeros and the
+     values greater than 1 by ones.
+
+     Inplace transformation of the `var` parameter.
+
+     Args:
+         var: a tensor or a collection of tensors
+     """
+     for X in var:
+         X.data = torch.clamp(X.data, 0, 1)
+
+
+ def phenotypeSuccession_constraint(Wk, Tw):
+     """Non-succession constraint: penalises occurrences of the same phenotype
+     at close time steps (within the temporal window). The lower the better.
+
+     Parameters
+     ----------
+     Wk: torch.Tensor
+         A 3rd-order tensor of size :math:`K * rank * (T-Tw+1)`
+     Tw: int
+         Length of the temporal window
+     """
+     O = torch.transpose(torch.stack([torch.eye(Wk[0].shape[0])] * (2 * Tw + 1)), 0, 2)
+     penalisation = reduce(
+         torch.add,
+         [
+             torch.sum(
+                 torch.clamp(
+                     Wp * torch.log(10e-8 + torch.conv1d(Wp, O, padding=Tw)), min=0
+                 )
+             )
+             for Wp in Wk
+         ],
+     )
+     return penalisation
+
+
+ def phenotype_uniqueness(Ph):
+     """Evaluate the redundancy between phenotypes. The larger the value, the
+     more redundant the phenotypes.
+     It computes the sum of pairwise similarities (dot products) between phenotypes.
+
+     Args:
+         Ph: collection of phenotypes (2D tensors)
+
+     Returns:
+         float: constraint value
+     """
+     Ph = torch.transpose(Ph, 1, 2)
+     ps = 0
+     for i in range(Ph.shape[0]):
+         for p1 in range(Ph.shape[1]):
+             for j in range(i + 1, Ph.shape[0]):
+                 for p2 in range(Ph.shape[1]):
+                     ps += torch.dot(Ph[i][p1], Ph[j][p2])
+
+     return ps.data
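The constraint functions above are combined with a reconstruction loss during the decomposition. The following is a minimal, illustrative sketch (not part of the package): the tensor shapes and weights are arbitrary, and the functions are assumed to be in scope as defined in the constraints module above.

    import torch

    # Hypothetical shapes: 4 phenotypes, 10 features, window of length 3,
    # and pathways for 5 patients over 20 - 3 + 1 time steps.
    Ph = torch.rand(4, 10, 3, requires_grad=True)
    Wk = torch.rand(5, 4, 18, requires_grad=True)

    # Regularisation terms added to a reconstruction loss (weights are illustrative).
    reg = 0.1 * sparsity_constraint(Ph) + 0.5 * phenotypeSuccession_constraint(Wk, 3)
    reg.backward()

    # Projections applied after each optimiser step to keep values in [0, 1].
    nonnegative_projection(Ph, Wk)
    normalization_constraint(Ph, Wk)

    # Diagnostic: redundancy between the phenotypes.
    print(phenotype_uniqueness(Ph))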
swotted/fastswotted.py ADDED
@@ -0,0 +1,604 @@
+ # -*- coding: utf-8 -*-
+ """This module implements the FastSWoTTeD reconstruction of tensors based
+ on the temporal convolution of the temporal phenotypes with the pathways.
+
+ This fast implementation decomposes a collection of pathways that all have
+ the same length.
+ """
+ import numpy as np
+ import lightning.pytorch as pl
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.autograd import Variable
+
+ from omegaconf import DictConfig
+ from munkres import Munkres
+
+ from swotted.loss_metrics import *
+ from swotted.temporal_regularization import *
+ from swotted.utils import *
+
+
+ class fastSWoTTeDDataset(Dataset):
+     """Implementation of a dataset class for `FastSwotted`.
+
+     The dataset uses a 3D tensor with dimensions :math:`K * N * T` where
+     :math:`K` is the number of individuals, :math:`N` the number of features
+     and :math:`T` the shared duration.
+
+     *FastSwotted* requires the pathways to be of the same length.
+     """
+
+     def __init__(self, dataset: torch.Tensor):
+         """
+         Parameters
+         ----------
+         dataset: torch.Tensor
+             3D tensor with dimensions :math:`K * N * T` where
+             :math:`K` is the number of individuals, :math:`N` the number of features
+             and :math:`T` the shared duration.
+         """
+         if not isinstance(dataset, torch.Tensor) or len(dataset.shape) != 3:
+             raise TypeError(
+                 "Invalid type for dataset: expected a 3D tensor"
+             )
+
+         self.dataset = dataset
+
+     def __getitem__(self, idx: int):
+         return self.dataset[idx, :, :], idx
+
+     def __len__(self):
+         return self.dataset.shape[0]
+
+
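As a hedged illustration (not part of the package): a synthetic binary tensor wrapped in the dataset and a DataLoader. The sizes are arbitrary, and the identity collate_fn is an assumption so that batches are lists of (sample, index) pairs, as expected by the training step below; the package may provide its own collate helper in swotted.utils.

    import torch
    from torch.utils.data import DataLoader
    from swotted.fastswotted import fastSWoTTeDDataset

    # Synthetic data: 50 patients, 12 features, 30 shared time steps.
    X = (torch.rand(50, 12, 30) > 0.8).float()
    dataset = fastSWoTTeDDataset(X)
    sample, idx = dataset[0]  # an N x T slice and its index

    # Batches must unpack with zip(*batch), hence the identity collate_fn.
    loader = DataLoader(dataset, batch_size=10, shuffle=True, collate_fn=lambda b: b)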
+ class SlidingWindowConv(nn.Module):
+     """Torch module that defines the convolution of the phenotypes with the
+     pathway matrix.
+     """
+
+     def __init__(self, dist=Loss()):
+         super().__init__()
+         self.setMetric(dist)
+
+     def setMetric(self, dist):
+         """Setter for the metric property
+
+         Parameters
+         ----------
+         dist: Loss
+             Selection of one loss metric used to evaluate the quality of the reconstruction
+
+         See
+         ---
+         loss_metrics.py"""
+         self.metric = dist
+
+     def reconstruct(self, W: torch.Tensor, Ph: torch.Tensor) -> torch.Tensor:
+         """Reconstruction function based on a convolution operator
+
+         Parameters
+         ----------
+         W: torch.Tensor
+             Pathway containing all occurrences of the phenotypes
+         Ph: torch.Tensor
+             Description of the phenotypes
+
+         Returns
+         -------
+         torch.Tensor
+             The reconstructed pathway that combines all occurrences of the phenotypes
+             along time."""
+         Y = torch.conv1d(W, Ph.transpose(0, 1), padding=Ph.shape[2] - 1).squeeze(dim=0)
+
+         if W.shape[0]:
+             Y = Y.unsqueeze(0)
+         return Y
+
+     def loss(self, Xp: torch.Tensor, Wp: torch.Tensor, Ph: torch.Tensor) -> float:
+         """
+         Parameters
+         ----------
+         Xp: torch.Tensor
+             a 2nd-order tensor of size :math:`N * Tp`, where :math:`N` is the number of
+             drugs and :math:`Tp` is the time of the patient's stay
+         Ph: torch.Tensor
+             phenotypes of size :math:`R * N * \omega`, where :math:`R` is the number of
+             phenotypes and :math:`\omega` the length of the temporal window
+         Wp: torch.Tensor
+             assignment tensor of size :math:`R * Tp` for patient :math:`p`
+         """
+         Yp = self.reconstruct(Wp.unsqueeze(dim=0), Ph)
+         # evaluate the loss
+         return self.metric.compute(Xp, Yp)
+
+     def forward(
+         self, X: torch.Tensor, W: torch.Tensor, Ph: torch.Tensor, padding: tuple = None
+     ) -> float:
+         """
+         Parameters
+         ----------
+         X: torch.Tensor
+             a 3rd-order tensor of size :math:`K * N * Tp`, where :math:`N` is the number of
+             drugs and :math:`Tp` is the time of the patients' stays
+         Ph: torch.Tensor
+             phenotypes of size :math:`R * N * \omega`, where :math:`R` is the number of
+             phenotypes and :math:`\omega` the length of the temporal window
+         W: torch.Tensor
+             assignment tensor of size :math:`K * R * Tp` for all patients
+         """
+         # W is a tensor of size K x R x Tp
+         Y = self.reconstruct(W, Ph)
+
+         if padding is not None:
+             if isinstance(padding, tuple) and len(padding) == 2:
+                 return self.metric.compute(
+                     torch.split(
+                         X,
+                         [padding[0], X.shape[1] - padding[0] - padding[1], padding[1]],
+                         dim=1,
+                     )[1],
+                     torch.split(
+                         Y,
+                         [padding[0], Y.shape[1] - padding[0] - padding[1], padding[1]],
+                         dim=1,
+                     )[1],
+                 )
+             else:
+                 raise ValueError(
+                     f"error in padding parameter, got {padding} and expected a tuple of length 2."
+                 )
+         else:
+             return self.metric.compute(X, Y)
+
+
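To make the convolution-based reconstruction concrete, here is a minimal, self-contained sketch of the operation performed by SlidingWindowConv.reconstruct (illustrative sizes; not part of the package):

    import torch

    # R=3 phenotypes, N=8 features, window w=4, one patient over T=20 time steps.
    R, N, w, T = 3, 8, 4, 20
    Ph = torch.rand(R, N, w)            # phenotypes
    Wp = torch.rand(1, R, T - w + 1)    # occurrences of each phenotype over time

    # An occurrence Wp[:, r, t] spreads the pattern Ph[r] (applied time-reversed by
    # conv1d, hence the "flip" warning below) over the time steps t .. t+w-1.
    Y = torch.conv1d(Wp, Ph.transpose(0, 1), padding=w - 1)
    print(Y.shape)  # torch.Size([1, 8, 20]): one patient, N features, T time steps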
+ class fastSWoTTeDModule(pl.LightningModule):
+     """
+
+     Warning
+     -------
+     The fastSWoTTeDModule has to be used with a fastSWoTTeDTrainer. This trainer
+     ensures the initialisation of the internal :math:`W` and :math:`O` tensors
+     when the dataset is known.
+
+     Warning
+     -------
+     The phenotypes that are discovered by this module have to be flipped to
+     correspond to the correct temporal order!
+
+     .. code-block:: python
+
+         swotted = fastSWoTTeDModule(config)
+         trainer = fastSWoTTeDTrainer()
+         ...
+         trainer.fit(swotted, train_dataloader)
+         ...
+         Ph = swotted.Ph
+         Ph = Ph.flip(2)
+     """
+
+     def __init__(self, config: DictConfig):
+         """
+         Parameters
+         ----------
+         config: (omegaconf.DictConfig)
+             Configuration of the model, training parameters and prediction parameters;
+             see FastswoTTed_test.py for the list of required parameters.
+         """
+         super().__init__()
+
+         # use config as parameter
+         self.params = config
+
+         self.model = SlidingWindowConv(eval(self.params.model.metric)())
+
+         try:
+             self.alpha = self.params.model.sparsity  # sparsity
+             self.beta = self.params.model.non_succession  # non-succession
+             self.adam = True
+
+             self.sparsity = self.params.model.sparsity > 0
+             self.pheno_succession = self.params.model.non_succession > 0
+             self.non_negativity = True
+             self.normalization = True
+
+             self.rank = self.params.model.rank
+             self.N = self.params.model.N
+             self.twl = self.params.model.twl
+         except Exception as exc:
+             print("Missing mandatory model parameters in the configuration")
+             raise exc
+
+         self.Ph = torch.nn.Parameter(
+             torch.rand(
+                 (self.params.model.rank, self.params.model.N, self.params.model.twl)
+             )
+         )
+
+         # Important: Wk is not directly part of the model. This torch variable is initialized in the trainer.
+         self.Wk = None
+
+         # O is a tool tensor for the non-succession constraint. It is initialized in the trainer.
+         self.O = None
+
+         # Important: This property activates manual optimization.
+         self.automatic_optimization = False
+
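Since the required configuration keys are only listed in a test file of the package, here is a hedged sketch of a configuration built with OmegaConf containing exactly the fields read by this module and by its training and prediction steps. The nesting (model/training/predict) comes from the code; the metric name "Frobenius" and all values are illustrative assumptions.

    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {
            "model": {
                "metric": "Frobenius",   # must name a class from swotted.loss_metrics (assumed)
                "sparsity": 0.1,         # weight of the L1 penalty on the phenotypes
                "non_succession": 0.5,   # weight of the non-succession penalty
                "rank": 4,               # number of phenotypes R
                "N": 12,                 # number of features
                "twl": 3,                # temporal window length
            },
            "training": {"lr": 1e-2},
            "predict": {"lr": 1e-2, "nepochs": 100},
        }
    )

    module = fastSWoTTeDModule(config)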
+     def configure_optimizers(self):
+         """
+         Parent override.
+         """
+
+         if self.adam:
+             optimizerPh = optim.Adam([self.Ph], lr=self.params.training.lr)
+         else:
+             optimizerPh = optim.SGD([self.Ph], lr=self.params.training.lr, momentum=0.9)
+
+         return optimizerPh
+
+     def forward(self, X: torch.Tensor) -> torch.Tensor:
+         """This forward function computes the decomposition of the tensor `X`.
+         It contains an optimisation stage to find the best decomposition.
+         The optimisation does not modify the phenotypes of the model.
+
+         Parameters
+         -----------
+         X: (torch.Tensor)
+             tensor of dimension :math:`K * N * T` to decompose according to the
+             phenotypes of the model
+
+         Returns
+         --------
+         torch.Tensor
+             A tensor of dimension :math:`K * R * (T-\omega+1)` that is the decomposition
+             of X according to the :math:`R` phenotypes of the model
+         """
+
+         K = X.shape[0]  # number of patients
+         if self.N != X.shape[1]:  # number of medical events
+             raise ValueError(
+                 f"The second dimension of X (number of features) is invalid (expected {self.N})."
+             )
+
+         with torch.inference_mode(False):
+             # PyTorch Lightning activates inference mode, which deeply disables the
+             # computation of gradients in this function; enable_grad() alone is not sufficient.
+             Wk_batch = Variable(
+                 torch.rand(K, self.rank, X.shape[2] - self.twl + 1), requires_grad=True
+             )
+             optimizerW = optim.Adam([Wk_batch], lr=self.params["predict"]["lr"])
+
+             n_epochs = self.params["predict"]["nepochs"]
+             for _ in range(n_epochs):
+
+                 def closure():
+                     optimizerW.zero_grad()
+                     loss = self.model(X, Wk_batch, self.Ph.data)
+                     if self.pheno_succession:
+                         loss += self.beta * self.phenotypeNonSuccession_loss(
+                             Wk_batch, self.twl
+                         )
+                     loss.backward()
+                     return loss
+
+                 optimizerW.step(closure)
+                 if self.non_negativity:
+                     Wk_batch.data[Wk_batch.data < 0] = 0
+                 if self.normalization:
+                     Wk_batch.data = torch.clamp(Wk_batch.data, 0, 1)
+         return Wk_batch
+
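A hedged usage sketch of this decomposition step, assuming the module has already been fitted by a fastSWoTTeDTrainer (so that its internal O tensor is set) and using the illustrative sizes of the earlier sketches (K=50, N=12, T=30, rank=4, twl=3):

    # Decompose held-out pathways with the frozen, trained phenotypes.
    X_new = (torch.rand(50, 12, 30) > 0.8).float()
    W_new = module(X_new)  # pathway tensor of size K x R x (T - twl + 1), here 50 x 4 x 28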
+     def predict_step(self, batch, batch_idx, dataloader_idx=0) -> float:
+         """
+         Parent override.
+         """
+         return self(batch)  # it only calls the forward function
+
+     def training_step(self, batch, idx):
+         """
+         Parent override.
+         """
+
+         optimizerPh = self.optimizers()
+
+         D, indices = zip(*batch)
+         X = torch.stack(D, dim=0)
+
+         Wk_batch = self.Wk[indices, :, :].detach()
+         Wk_batch.requires_grad_(True)
+         Wk_batch_nograd = self.Wk[indices, :, :].data
+
+         if self.adam:
+             optimizerW = optim.Adam([Wk_batch], lr=self.params.training.lr)
+         else:
+             optimizerW = optim.SGD([Wk_batch], lr=self.params.training.lr, momentum=0.9)
+
+         def closure():
+             optimizerPh.zero_grad()
+             loss = self.model(X, Wk_batch_nograd, self.Ph)
+             self.log(
+                 "train_reconstr_Ph",
+                 loss,
+                 on_step=True,
+                 on_epoch=False,
+                 prog_bar=False,
+                 logger=True,
+             )
+             if self.sparsity:
+                 sparsity_loss = torch.norm(self.Ph, 1)
+                 self.log(
+                     "train_sparsity_Ph",
+                     sparsity_loss,
+                     on_step=True,
+                     on_epoch=False,
+                     prog_bar=False,
+                     logger=True,
+                 )
+                 loss += self.alpha * sparsity_loss
+             loss.backward()
+             self.log(
+                 "train_loss_Ph",
+                 loss,
+                 on_step=True,
+                 on_epoch=False,
+                 prog_bar=False,
+                 logger=True,
+             )
+             return loss
+
+         optimizerPh.step(closure)
+
+         if self.non_negativity:
+             self.Ph.data[self.Ph.data < 0] = 0
+         if self.normalization:
+             self.Ph.data = torch.clamp(self.Ph.data, 0, 1)
+
+         # update W
+         def closure():
+             optimizerW.zero_grad()
+             loss = self.model(X, Wk_batch, self.Ph.data)
+             self.log(
+                 "train_reconstr_W",
+                 loss,
+                 on_step=True,
+                 on_epoch=False,
+                 prog_bar=False,
+                 logger=True,
+             )
+             if self.pheno_succession:
+                 nonsucc_loss = self.phenotypeNonSuccession_loss(Wk_batch, self.twl)
+                 self.log(
+                     "train_nonsucc_W",
+                     nonsucc_loss,
+                     on_step=True,
+                     on_epoch=False,
+                     prog_bar=False,
+                     logger=True,
+                 )
+                 loss += self.beta * nonsucc_loss
+             loss.backward()
+             self.log(
+                 "train_loss_W",
+                 loss,
+                 on_step=True,
+                 on_epoch=False,
+                 prog_bar=False,
+                 logger=True,
+             )
+             return loss
+
+         optimizerW.step(closure)
+         if self.non_negativity:
+             Wk_batch.data[Wk_batch.data < 0] = 0
+         if self.normalization:
+             Wk_batch.data = torch.clamp(Wk_batch.data, 0, 1)
+
+         self.Wk[indices, :, :] = Wk_batch
+
+     def test_step(self, batch, batch_idx) -> float:
+         X, _ = zip(*batch)
+         W_hat = self(X)
+         loss = self.model(X, W_hat, self.Ph)
+         # self.log("test_loss", loss)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         """
+         Parent override.
+
+         ***This function has not been tested***
+         """
+         X, y = zip(*batch)
+         W_hat = self(
+             X
+         )  # Apply the model on the data (requires optimisation of local W)
+         loss = self.model(X, W_hat, self.Ph)
+         # self.log("val_loss", loss)
+         return loss
+
+     def forecast(self, X: torch.Tensor) -> torch.Tensor:
+         """This function forecasts the next time step using the trained phenotypes.
+         It can be used only with the parameter :math:`\\omega\\geq 2` (`twl>=2`),
+         i.e. phenotypes spanning at least two time steps.
+
+         This function makes a projection of the data with the phenotypes of the model.
+
+         For computational efficiency, the time dimension of :math:`X` is reduced to
+         its last :math:`\\omega-1` time steps, and then extended on the right with
+         :math:`\\omega` empty time steps.
+
+         Parameters
+         ----------
+         X: (torch.Tensor)
+             tensor of dimension :math:`K * N * T` to decompose
+             according to the phenotypes of the model.
+
+         Returns
+         --------
+         torch.Tensor
+             A tensor of dimension :math:`K * N` that is the forecast of the
+             next time step of :math:`X`.
+         """
+
+         if self.twl < 2:
+             # trained with daily phenotypes
+             raise ValueError(
+                 "The width of the phenotypes does not allow making forecasts. \
+                 It is possible only with phenotypes having a width > 1."
+             )
+
+         K = X.shape[0]  # number of patients
+         if self.N != X.shape[1]:  # number of medical events
+             raise ValueError(
+                 f"The second dimension of X (number of features) is invalid (expected {self.N})."
+             )
+
+         # keep the last "window" of size twl-1 and append zeros of length twl (region to predict)
+         X = torch.cat(
+             (X[:, :, -(self.twl - 1) :], torch.zeros((K, self.N, self.twl))), axis=2
+         )
+
+         # now, we decompose the tensor ... without considering the last part of the
+         # reconstruction, ie the predicted part
+         with torch.inference_mode(False):
+             # PyTorch Lightning activates inference mode, which deeply disables the
+             # computation of gradients in this function; enable_grad() alone is not sufficient.
+
+             Wk_batch = Variable(
+                 torch.rand(K, self.rank, X.shape[2] - self.twl + 1), requires_grad=True
+             )
+             optimizerW = optim.Adam([Wk_batch], lr=self.params["predict"]["lr"])
+
+             n_epochs = self.params["predict"]["nepochs"]
+             for _ in range(n_epochs):
+
+                 def closure():
+                     optimizerW.zero_grad()
+                     # evaluate the loss based on the beginning of the reconstruction only
+                     loss = self.model(X, Wk_batch, self.Ph.data, padding=(0, self.twl))
+                     if self.pheno_succession:
+                         loss += self.beta * self.phenotypeNonSuccession_loss(
+                             Wk_batch, self.twl
+                         )
+                     loss.backward()
+                     return loss
+
+                 optimizerW.step(closure)
+                 if self.non_negativity:
+                     Wk_batch.data[Wk_batch.data < 0] = 0
+                 if self.normalization:
+                     Wk_batch.data = torch.clamp(Wk_batch.data, 0, 1)
+
+         # make a reconstruction, and select only the next event
+         with torch.no_grad():
+             pred = self.model.reconstruct(Wk_batch, self.Ph.data)[:, :, self.twl]
+         return pred
+
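A hedged forecasting sketch, assuming a module fitted by fastSWoTTeDTrainer (so that module.O is initialised) with the illustrative configuration above (twl=3 >= 2):

    # Forecast the next time step from each patient's history (sizes are illustrative).
    X_hist = (torch.rand(50, 12, 30) > 0.8).float()
    X_next = module.forecast(X_hist)  # forecast of the next time step of X_hist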
+     def phenotypeNonSuccession_loss(self, Wk: torch.Tensor, Tw: int):
+         """Definition of a loss that pushes the decomposition to put the
+         description in the phenotypes rather than in the pathways.
+
+         Parameters
+         ----------
+         Wk: torch.Tensor
+             A 3rd order tensor of size :math:`K * R * (T-\\omega+1)`
+         Tw: int
+             Length of the temporal window
+         """
+         return torch.sum(
+             torch.clamp(
+                 Wk * torch.log(10e-8 + torch.conv1d(Wk, self.O, padding=Tw)), min=0
+             )
+         )
+
+     def reorderPhenotypes(
+         self, gen_pheno: torch.Tensor, Wk: torch.Tensor = None, tw: int = 2
+     ) -> torch.Tensor:
+         """
+         This function outputs reordered internal phenotypes and pathways.
+
+         Parameters
+         ----------
+         gen_pheno: (torch.Tensor)
+             generated phenotypes of size :math:`R * N * \\omega`, where :math:`R` is the number of
+             phenotypes, :math:`N` is the number of drugs and :math:`\\omega` is the length of the
+             temporal window
+         Wk: (torch.Tensor)
+             pathways to reorder. If None, the function uses the internal pathways
+         tw: (int)
+             window size (:math:`\\omega`)
+
+         Returns
+         -------
+         torch.Tensor
+             A pair `(rPh, rW)` with the reordered phenotypes (aligned at best with `gen_pheno`)
+             and the corresponding reordering of the pathways
+         """
+         if Wk is None:
+             Wk = self.Wk
+
+         if tw == 1:
+             gen_pheno = torch.unsqueeze(gen_pheno, 2)  # transform into a matrix
+
+         if gen_pheno[0].shape != self.Ph[0].shape:
+             raise ValueError(
+                 "The generated phenotypes and the computed phenotypes do not have the same shape"
+             )
+
+         dic = np.zeros(
+             (gen_pheno.shape[0], self.Ph.shape[0])
+         )  # construct a cost matrix
+
+         Ph = self.Ph.flip(2)
+
+         for i in range(gen_pheno.shape[0]):
+             for j in range(Ph.shape[0]):
+                 dic[i][j] = torch.norm((gen_pheno[i] - Ph[j]), p="fro").item()
+
+         m = Munkres()  # use of the Hungarian algorithm to find phenotype correspondences
+         indexes = m.compute(dic)
+
+         # Reorder phenotypes
+         reordered_pheno = Ph.clone()
+         for row, column in indexes:
+             reordered_pheno[row] = Ph[column]
+
+         # Reorder pathways
+         reordered_pathways = Wk.clone()
+         for row, column in indexes:
+             reordered_pathways[:, row] = Wk[:, column]
+
+         return reordered_pheno, reordered_pathways
+
+
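A hedged sketch of phenotype reordering in a simulation study, where gen_pheno is a hypothetical R x N x twl tensor of ground-truth phenotypes and module has been fitted as in the other sketches:

    rPh, rW = module.reorderPhenotypes(gen_pheno, tw=config.model.twl)
    # rPh[i] is the discovered phenotype best matching gen_pheno[i]; rW is reordered accordingly.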
+ class fastSWoTTeDTrainer(pl.Trainer):
+     """Trainer for fast-SWoTTeD
+
+     This class redefines the Lightning trainer to take into account the
+     specificity of the training procedure in tensor decomposition.
+     """
+
+     def fit(
+         self,
+         model: fastSWoTTeDModule,
+         train_dataloaders,
+         val_dataloaders=None,
+         datamodule=None,
+         ckpt_path=None,
+     ):
+         model.Wk = Variable(
+             torch.rand(
+                 len(train_dataloaders.dataset),
+                 model.rank,
+                 train_dataloaders.dataset[0][0].shape[1] - model.twl + 1,
+             ),
+             requires_grad=False,
+         )
+
+         # O is a matrix used for the non-succession constraint
+         model.O = torch.transpose(
+             torch.stack([torch.eye(model.Wk.shape[1])] * (2 * model.twl + 1)), 0, 2
+         )
+         ret = super().fit(
+             model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+         )
+
+         return ret
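An end-to-end training sketch, reusing the dataset/loader, configuration, and module from the earlier sketches (the sizes, the metric name and the collate_fn are illustrative assumptions):

    trainer = fastSWoTTeDTrainer(max_epochs=200)  # fit() initialises module.Wk and module.O
    trainer.fit(module, loader)

    # As warned in the module docstring, flip the phenotypes to read them in natural temporal order.
    Ph = module.Ph.detach().flip(2)  # R x N x twl
    W = module.Wk                    # K x R x (T - twl + 1), one pathway per training patient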