swotted-1.0.2a4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- swotted/__init__.py +17 -0
- swotted/decomposition_contraints.py +87 -0
- swotted/fastswotted.py +604 -0
- swotted/loss_metrics.py +39 -0
- swotted/slidingWindow_model.py +161 -0
- swotted/swotted.py +417 -0
- swotted/temporal_regularization.py +56 -0
- swotted/utils.py +52 -0
- swotted/version.py +9 -0
- swotted-1.0.2a4.dist-info/METADATA +249 -0
- swotted-1.0.2a4.dist-info/RECORD +13 -0
- swotted-1.0.2a4.dist-info/WHEEL +4 -0
- swotted-1.0.2a4.dist-info/licenses/LICENSE.txt +165 -0
swotted/__init__.py
ADDED
@@ -0,0 +1,17 @@
+"""
+SWoTTeD Module
+
+@author: Hana Sebia and Thomas Guyet
+@date: 2023
+@institution: Inria
+"""
+
+from swotted.swotted import swottedModule, swottedTrainer
+from swotted.fastswotted import (
+    fastSWoTTeDDataset,
+    fastSWoTTeDModule,
+    fastSWoTTeDTrainer,
+)
+
+
+name = "swotted"
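The `__init__.py` above only re-exports the package's public entry points, so they can be imported directly from the top-level package. A minimal import sketch (assuming the wheel above is installed):

    # Import sketch for the names re-exported by swotted/__init__.py.
    from swotted import swottedModule, swottedTrainer
    from swotted import fastSWoTTeDDataset, fastSWoTTeDModule, fastSWoTTeDTrainer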
swotted/decomposition_contraints.py
ADDED
@@ -0,0 +1,87 @@
+"""
+SWoTTeD Module: decomposition constraints
+"""
+
+from functools import reduce
+import torch
+
+
+def sparsity_constraint(var):
+    """Sparsity constraint (L1 norm) for a tensor `var`. The lower the better.
+
+    Args:
+        var (torch.Tensor): a tensor
+
+    Returns:
+        float: constraint value
+    """
+    return torch.norm(var, 1)
+
+
+def nonnegative_projection(*var):
+    """Transform a tensor by replacing the negative values by zeros.
+
+    Inplace transformation of the `var` parameter.
+
+    Args:
+        var: collection of tensors or tensor
+    """
+    for X in var:
+        X.data[X.data < 0] = 0
+
+
+def normalization_constraint(*var):
+    """Transform a tensor by replacing the negative values by zeros and
+    values greater than 1 by ones.
+
+    Inplace transformation of the `var` parameter.
+
+    Args:
+        var: collection of tensors or tensor
+    """
+    for X in var:
+        X.data = torch.clamp(X.data, 0, 1)
+
+
+def phenotypeSuccession_constraint(Wk, Tw):
+    """
+    Parameters
+    ----------
+    Wk: torch.Tensor
+        A 3rd order tensor of size :math:`K * rank * (T-Tw+1)`
+    """
+    O = torch.transpose(torch.stack([torch.eye(Wk[0].shape[0])] * (2 * Tw + 1)), 0, 2)
+    penalisation = reduce(
+        torch.add,
+        [
+            torch.sum(
+                torch.clamp(
+                    Wp * torch.log(10e-8 + torch.conv1d(Wp, O, padding=Tw)), min=0
+                )
+            )
+            for Wp in Wk
+        ],
+    )
+    return penalisation
+
+
+def phenotype_uniqueness(Ph):
+    """Evaluate the redundancy between phenotypes. The larger, the more redundant
+    the phenotypes are.
+    It computes the sum of pairwise cosine similarities (dot products) between phenotypes.
+
+    Args:
+        Ph: collection of phenotypes (2D tensors)
+
+    Returns:
+        float: constraint value
+    """
+    Ph = torch.transpose(Ph, 1, 2)
+    ps = 0
+    for i in range(Ph.shape[0]):
+        for p1 in range(Ph.shape[1]):
+            for j in range(i + 1, Ph.shape[0]):
+                for p2 in range(Ph.shape[1]):
+                    ps += torch.dot(Ph[i][p1], Ph[j][p2])
+
+    return ps.data
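These helpers are meant to be combined with a decomposition loop: the two constraint functions return scalar penalties to add to a loss, while the two projection functions modify tensors in place after an optimizer step. A small sketch on toy tensors (the shapes below are illustrative, not imposed by the module; the module path keeps the spelling used in the package):

    import torch
    from swotted.decomposition_contraints import (
        sparsity_constraint,
        nonnegative_projection,
        normalization_constraint,
        phenotype_uniqueness,
    )

    Ph = torch.rand(4, 10, 3)   # 4 phenotypes, 10 features, temporal window of 3
    W = torch.randn(4, 20)      # an assignment matrix with some negative entries

    penalty = sparsity_constraint(Ph)      # L1 norm, added to a loss to favour sparse phenotypes
    redundancy = phenotype_uniqueness(Ph)  # sum of pairwise dot products between phenotypes

    nonnegative_projection(W)              # in place: negative entries are set to 0
    normalization_constraint(W, Ph)        # in place: entries are clipped to [0, 1]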
swotted/fastswotted.py
ADDED
@@ -0,0 +1,604 @@
+# -*- coding: utf-8 -*-
+"""This module implements the FastSWoTTeD reconstruction of tensors based
+on the temporal convolution of the temporal phenotypes with the pathways.
+
+This fast implementation decomposes collections of pathways that all have
+the same length.
+"""
+import numpy as np
+import lightning.pytorch as pl
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.autograd import Variable
+
+from omegaconf import DictConfig
+from munkres import Munkres
+
+from swotted.loss_metrics import *
+from swotted.temporal_regularization import *
+from swotted.utils import *
+
+
+class fastSWoTTeDDataset(Dataset):
+    """Implementation of a dataset class for `FastSwotted`.
+
+    The dataset uses a 3D tensor with dimensions :math:`K * N * T` where
+    :math:`K` is the number of individuals, :math:`N` the number of features
+    and :math:`T` the shared duration.
+
+    *FastSwotted* requires the pathways to be of the same length.
+    """
+
+    def __init__(self, dataset: torch.Tensor):
+        """
+        Parameters
+        ----------
+        dataset: torch.Tensor
+            3D Tensor with dimensions :math:`K * N * T` where
+            :math:`K` is the number of individuals, :math:`N` the number of features
+            and :math:`T` the shared duration.
+        """
+        if not isinstance(dataset, torch.Tensor) or len(dataset.shape) != 3:
+            raise TypeError(
+                "Invalid type for dataset: expected a 3rd-order tensor"
+            )
+
+        self.dataset = dataset
+
+    def __getitem__(self, idx: int):
+        return self.dataset[idx, :, :], idx
+
+    def __len__(self):
+        return self.dataset.shape[0]
+
+
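A minimal sketch of wrapping data into this dataset. Using `collate_fn=list` is an assumption made here so that a batch is a list of `(slice, index)` pairs, which matches how `training_step` below unpacks batches with `zip(*batch)`; the package's `utils.py` (imported with `*` above) may provide its own helper for this.

    import torch
    from torch.utils.data import DataLoader
    from swotted import fastSWoTTeDDataset

    X = torch.rand(50, 10, 30)   # 50 individuals, 10 features, 30 shared time steps
    dataset = fastSWoTTeDDataset(X)
    xk, k = dataset[0]           # a (10, 30) slice and its index in the tensor
    # collate_fn=list keeps each batch as a list of (slice, index) pairs
    loader = DataLoader(dataset, batch_size=8, collate_fn=list)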
+class SlidingWindowConv(nn.Module):
+    """Torch module that defines the convolution of the phenotypes with the
+    pathway matrix
+    """
+
+    def __init__(self, dist=Loss()):
+        super().__init__()
+        self.setMetric(dist)
+
+    def setMetric(self, dist):
+        """Setter for the metric property
+
+        Parameters
+        ----------
+        dist: Loss
+            Selection of one loss metric used to evaluate the quality of the reconstruction
+
+        See
+        ---
+        loss_metrics.py"""
+        self.metric = dist
+
+    def reconstruct(self, W: torch.Tensor, Ph: torch.Tensor) -> torch.Tensor:
+        """Reconstruction function based on a convolution operator
+
+        Parameters
+        ----------
+        W: torch.Tensor
+            Pathway containing all occurrences of the phenotypes
+        Ph: torch.Tensor
+            Description of the phenotypes
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed pathway that combines all occurrences of the phenotypes
+            along time."""
+        Y = torch.conv1d(W, Ph.transpose(0, 1), padding=Ph.shape[2] - 1).squeeze(dim=0)
+
+        if W.shape[0]:
+            Y = Y.unsqueeze(0)
+        return Y
+
+    def loss(self, Xp: torch.Tensor, Wp: torch.Tensor, Ph: torch.Tensor) -> float:
+        """
+        Parameters
+        ----------
+        Xp: torch.Tensor
+            a 2nd-order tensor of size :math:`N * Tp`, where :math:`N` is the number of
+            drugs and :math:`Tp` is the time of the patient's stay
+        Ph: torch.Tensor
+            phenotypes of size :math:`R * N * \omega`, where :math:`R` is the number of
+            phenotypes and :math:`\omega` the length of the temporal window
+        Wp: torch.Tensor
+            assignment tensor of size :math:`R * Tp` for patient :math:`p`
+        """
+        Yp = self.reconstruct(Wp.unsqueeze(dim=0), Ph)
+        # evaluate the loss
+        return self.metric.compute(Xp, Yp)
+
+    def forward(
+        self, X: torch.Tensor, W: torch.Tensor, Ph: torch.Tensor, padding: bool = None
+    ) -> float:
+        """
+        Parameters
+        ----------
+        X: torch.Tensor
+            a 3rd-order tensor of size :math:`K * N * Tp`, where :math:`N` is the number of
+            drugs and :math:`Tp` is the time of the patients' stays
+        Ph: torch.Tensor
+            phenotypes of size :math:`R * N * \omega`, where :math:`R` is the number of
+            phenotypes and :math:`\omega` the length of the temporal window
+        W: torch.Tensor
+            assignment tensor of size :math:`K * R * Tp` for all patients
+        """
+        # W is a tensor of size K x R x Tp
+        Y = self.reconstruct(W, Ph)
+
+        if padding is not None:
+            if isinstance(padding, tuple) and len(padding) == 2:
+                return self.metric.compute(
+                    torch.split(
+                        X,
+                        [padding[0], X.shape[1] - padding[0] - padding[1], padding[1]],
+                        dim=1,
+                    )[1],
+                    torch.split(
+                        Y,
+                        [padding[0], Y.shape[1] - padding[0] - padding[1], padding[1]],
+                        dim=1,
+                    )[1],
+                )
+            else:
+                raise ValueError(
+                    f"error in padding parameter, got {padding} and expected a tuple of length 2."
+                )
+        else:
+            return self.metric.compute(X, Y)
+
+
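The reconstruction is a 1D convolution of the assignment tensor with the transposed phenotypes, padded with `twl - 1` so that the output recovers the input duration `T`. A small shape check built only from the class above:

    import torch
    from swotted.fastswotted import SlidingWindowConv

    R, N, twl, T = 3, 10, 4, 30        # rank, features, window length, duration
    Ph = torch.rand(R, N, twl)         # phenotypes
    W = torch.rand(1, R, T - twl + 1)  # assignment tensor for a single individual

    conv = SlidingWindowConv()         # uses the default Loss() metric
    Y = conv.reconstruct(W, Ph)        # shape (1, N, T): phenotypes convolved along time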
+class fastSWoTTeDModule(pl.LightningModule):
+    """
+
+    Warning
+    -------
+    The fastSWoTTeDModule has to be used with a fastSWoTTeDTrainer. This trainer
+    ensures the initialisation of the internal :math:`W` and :math:`O` tensors
+    when the dataset is known.
+
+    Warning
+    -------
+    The phenotypes discovered by this module have to be flipped to
+    correspond to the correct temporal order!
+
+    .. code-block:: python
+
+        swotted = fastSWoTTeDModule()
+        ...
+        swotted.fit()
+        ...
+        Ph = swotted.Ph
+        Ph = Ph.flip(2)
+    """
+
+    def __init__(self, config: DictConfig):
+        """
+        Parameters
+        ----------
+        config: (omegaconf.DictConfig)
+            Configuration of the model, training parameters and prediction parameters;
+            see FastswoTTed_test.py for the list of required parameters.
+        """
+        super().__init__()
+
+        # use config as parameter
+        self.params = config
+
+        self.model = SlidingWindowConv(eval(self.params.model.metric)())
+
+        try:
+            self.alpha = self.params.model.sparsity  # sparsity
+            self.beta = self.params.model.non_succession  # non-succession
+            self.adam = True
+
+            self.sparsity = self.params.model.sparsity > 0
+            self.pheno_succession = self.params.model.non_succession > 0
+            self.non_negativity = True
+            self.normalization = True
+
+            self.rank = self.params.model.rank
+            self.N = self.params.model.N
+            self.twl = self.params.model.twl
+        except Exception as exc:
+            print("Missing mandatory model parameters in the configuration")
+            raise exc
+
+        self.Ph = torch.nn.Parameter(
+            torch.rand(
+                (self.params.model.rank, self.params.model.N, self.params.model.twl)
+            )
+        )
+
+        # Important: Wk is not directly part of the model. This torch variable is initialized in the trainer.
+        self.Wk = None
+
+        # O is a tool tensor for the non-succession constraint. It is initialized in the trainer.
+        self.O = None
+
+        # Important: This property activates manual optimization.
+        self.automatic_optimization = False
+
+    def configure_optimizers(self):
+        """
+        Parent override.
+        """
+
+        if self.adam:
+            optimizerPh = optim.Adam([self.Ph], lr=self.params.training.lr)
+        else:
+            optimizerPh = optim.SGD([self.Ph], lr=self.params.training.lr, momentum=0.9)
+
+        return optimizerPh
+
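The constructor reads its hyper-parameters from an OmegaConf configuration; the docstring defers to FastswoTTed_test.py for the full list. The sketch below is assembled only from the fields actually accessed in this diff (`model.*`, `training.lr`, `predict.lr`, `predict.nepochs`), with illustrative values. Because the `metric` string is passed to `eval()`, it must name a loss class visible in this module; `"Loss"` is assumed here since it is the default used by `SlidingWindowConv`, but any metric defined in `loss_metrics.py` should work:

    from omegaconf import OmegaConf
    from swotted import fastSWoTTeDModule

    config = OmegaConf.create(
        {
            "model": {
                "rank": 3,              # number of phenotypes R
                "N": 10,                # number of features
                "twl": 4,               # temporal window length (omega)
                "metric": "Loss",       # name of a loss class from loss_metrics.py (assumption)
                "sparsity": 0.1,        # weight of the L1 penalty on Ph (0 disables it)
                "non_succession": 0.1,  # weight of the non-succession penalty on W (0 disables it)
            },
            "training": {"lr": 1e-2},
            "predict": {"lr": 1e-2, "nepochs": 100},
        }
    )
    module = fastSWoTTeDModule(config)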
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        """This forward function makes the decomposition of the tensor `X`.
+        It contains an optimisation stage to find the best decomposition.
+        The optimisation does not modify the phenotypes of the model.
+
+        Parameters
+        ----------
+        X: (torch.Tensor)
+            tensor of dimension :math:`K * N * T` to decompose according to the
+            phenotypes of the model
+
+        Returns
+        -------
+        torch.Tensor
+            A tensor of dimension :math:`K * R * (T-\omega+1)` that is the decomposition
+            of X according to the :math:`R` phenotypes of the model
+        """
+
+        K = X.shape[0]  # number of patients
+        if self.N != X.shape[1]:  # number of medical events
+            raise ValueError(
+                f"The second dimension of X (number of features) is invalid (expected {self.N})."
+            )
+
+        with torch.inference_mode(False):
+            # Lightning activates inference mode, which deeply disables gradient
+            # computation in this function; enable_grad() alone is not sufficient.
+            Wk_batch = Variable(
+                torch.rand(K, self.rank, X.shape[2] - self.twl + 1), requires_grad=True
+            )
+            optimizerW = optim.Adam([Wk_batch], lr=self.params["predict"]["lr"])
+
+            n_epochs = self.params["predict"]["nepochs"]
+            for _ in range(n_epochs):
+
+                def closure():
+                    optimizerW.zero_grad()
+                    loss = self.model(X, Wk_batch, self.Ph.data)
+                    if self.pheno_succession:
+                        loss += self.beta * self.phenotypeNonSuccession_loss(
+                            Wk_batch, self.twl
+                        )
+                    loss.backward()
+                    return loss
+
+                optimizerW.step(closure)
+                if self.non_negativity:
+                    Wk_batch.data[Wk_batch.data < 0] = 0
+                if self.normalization:
+                    Wk_batch.data = torch.clamp(Wk_batch.data, 0, 1)
+        return Wk_batch
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0) -> float:
+        """
+        Parent override.
+        """
+        return self(batch)  # it only calls the forward function
+
+    def training_step(self, batch, idx):
+        """
+        Parent override.
+        """
+
+        optimizerPh = self.optimizers()
+
+        D, indices = zip(*batch)
+        X = torch.stack(D, dim=0)
+
+        Wk_batch = self.Wk[indices, :, :].detach()
+        Wk_batch.requires_grad_(True)
+        Wk_batch_nograd = self.Wk[indices, :, :].data
+
+        if self.adam:
+            optimizerW = optim.Adam([Wk_batch], lr=self.params.training.lr)
+        else:
+            optimizerW = optim.SGD([Wk_batch], lr=self.params.training.lr, momentum=0.9)
+
+        def closure():
+            optimizerPh.zero_grad()
+            loss = self.model(X, Wk_batch_nograd, self.Ph)
+            self.log(
+                "train_reconstr_Ph",
+                loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+            if self.sparsity:
+                sparsity_loss = torch.norm(self.Ph, 1)
+                self.log(
+                    "train_sparsity_Ph",
+                    sparsity_loss,
+                    on_step=True,
+                    on_epoch=False,
+                    prog_bar=False,
+                    logger=True,
+                )
+                loss += self.alpha * sparsity_loss
+            loss.backward()
+            self.log(
+                "train_loss_Ph",
+                loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+            return loss
+
+        optimizerPh.step(closure)
+
+        if self.non_negativity:
+            self.Ph.data[self.Ph.data < 0] = 0
+        if self.normalization:
+            self.Ph.data = torch.clamp(self.Ph.data, 0, 1)
+
+        # update W
+        def closure():
+            optimizerW.zero_grad()
+            loss = self.model(X, Wk_batch, self.Ph.data)
+            self.log(
+                "train_reconstr_W",
+                loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+            if self.pheno_succession:
+                nonsucc_loss = self.phenotypeNonSuccession_loss(Wk_batch, self.twl)
+                self.log(
+                    "train_nonsucc_W",
+                    nonsucc_loss,
+                    on_step=True,
+                    on_epoch=False,
+                    prog_bar=False,
+                    logger=True,
+                )
+                loss += self.beta * nonsucc_loss
+            loss.backward()
+            self.log(
+                "train_loss_W",
+                loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+            return loss
+
+        optimizerW.step(closure)
+        if self.non_negativity:
+            Wk_batch.data[Wk_batch.data < 0] = 0
+        if self.normalization:
+            Wk_batch.data = torch.clamp(Wk_batch.data, 0, 1)
+
+        self.Wk[indices, :, :] = Wk_batch
+
+    def test_step(self, batch, batch_idx) -> float:
+        X, _ = zip(*batch)
+        W_hat = self(X)
+        loss = self.model(X, W_hat, self.Ph)
+        # self.log("test_loss", loss)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        """
+        Parent override.
+
+        ***This function has not been tested***
+        """
+        X, y = zip(*batch)
+        W_hat = self(
+            X
+        )  # Apply the model on the data (requires optimisation of local W)
+        loss = self.model(X, W_hat, self.Ph)
+        # self.log("val_loss", loss)
+        return loss
+
+    def forecast(self, X: torch.Tensor) -> torch.Tensor:
+        """This function forecasts the next time step using the trained phenotypes.
+        This function can be used only with the parameter :math:`\\omega\\geq 2` (`twl>=2`),
+        i.e. phenotypes spanning at least two time instants.
+
+        This function makes a projection of the data with the phenotypes of the model.
+
+        For computational efficiency, the time dimension of :math:`X` is reduced to its
+        last :math:`\\omega-1` time steps, and then extended with :math:`\\omega` empty
+        time steps on the right.
+
+        Parameters
+        ----------
+        X: (torch.Tensor)
+            tensor of dimension :math:`K * N * T` to decompose
+            according to the phenotypes of the model.
+
+        Returns
+        -------
+        torch.Tensor
+            A tensor of dimension :math:`K * N` that is the forecast of the
+            next time step of :math:`X`.
+        """
+
+        if self.twl < 2:
+            # trained with daily phenotypes
+            raise ValueError(
+                "The width of the phenotypes does not allow to make forecasts. \
+                It is possible only with phenotypes having a width>1."
+            )
+
+        K = X.shape[0]  # number of patients
+        if self.N != X.shape[1]:  # number of medical events
+            raise ValueError(
+                f"The second dimension of X (number of features) is invalid (expected {self.N})."
+            )
+
+        # keep the last (twl - 1) observed time steps and append twl empty steps (the region to predict)
+        X = torch.cat(
+            (X[:, :, -(self.twl - 1) :], torch.zeros((K, self.N, self.twl))), axis=2
+        )
+
+        # now, we decompose the tensor ... without considering the last part of the
+        # reconstruction, ie the predicted part
+        with torch.inference_mode(False):
+            # Lightning activates inference mode, which deeply disables gradient
+            # computation in this function; enable_grad() alone is not sufficient.
+
+            Wk_batch = Variable(
+                torch.rand(K, self.rank, X.shape[2] - self.twl + 1), requires_grad=True
+            )
+            optimizerW = optim.Adam([Wk_batch], lr=self.params["predict"]["lr"])
+
+            n_epochs = self.params["predict"]["nepochs"]
+            for _ in range(n_epochs):
+
+                def closure():
+                    optimizerW.zero_grad()
+                    # evaluate the loss based on the beginning of the reconstruction only
+                    loss = self.model(X, Wk_batch, self.Ph.data, padding=(0, self.twl))
+                    if self.pheno_succession:
+                        loss += self.beta * self.phenotypeNonSuccession_loss(
+                            Wk_batch, self.twl
+                        )
+                    loss.backward()
+                    return loss
+
+                optimizerW.step(closure)
+                if self.non_negativity:
+                    Wk_batch.data[Wk_batch.data < 0] = 0
+                if self.normalization:
+                    Wk_batch.data = torch.clamp(Wk_batch.data, 0, 1)
+
+        # make a reconstruction, and select only the next event
+        with torch.no_grad():
+            pred = self.model.reconstruct(Wk_batch, self.Ph.data)[:, :, self.twl]
+        return pred
+
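Once the module has been fitted, `forecast` projects the last `twl - 1` observed time steps onto the phenotypes and reads the reconstruction one step ahead. A usage sketch, assuming a module already trained with `fastSWoTTeDTrainer` (so that `module.O` is initialised) and `twl >= 2`:

    import torch

    # `module` is assumed to be a fastSWoTTeDModule already fitted with fastSWoTTeDTrainer.
    X = torch.rand(1, module.N, 30)  # observed history of a single individual
    next_step = module.forecast(X)   # shape (1, module.N): one predicted time step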
+    def phenotypeNonSuccession_loss(self, Wk: torch.Tensor, Tw: torch.Tensor):
+        """Definition of a loss that pushes the decomposition to put the
+        description preferably in the phenotypes rather than in the pathways.
+
+        Parameters
+        ----------
+        Wk: torch.Tensor
+            A 3rd order tensor of size :math:`K * R * (T-\\omega+1)`
+        """
+        return torch.sum(
+            torch.clamp(
+                Wk * torch.log(10e-8 + torch.conv1d(Wk, self.O, padding=Tw)), min=0
+            )
+        )
+
+    def reorderPhenotypes(
+        self, gen_pheno: torch.Tensor, Wk: torch.Tensor = None, tw: int = 2
+    ) -> torch.Tensor:
+        """
+        This function outputs reordered internal phenotypes and pathways.
+
+        Parameters
+        ----------
+        gen_pheno: (torch.Tensor)
+            generated phenotypes of size :math:`R * N * \\omega`, where :math:`R` is the number of
+            phenotypes, :math:`N` is the number of drugs and :math:`\\omega` is the length of the
+            temporal window
+        Wk: (torch.Tensor)
+            pathway to reorder. If None, the function uses the internal pathways
+        tw: (int)
+            window size (:math:`\\omega`)
+
+        Returns
+        -------
+        torch.Tensor
+            A pair `(rPh, rW)` with reordered phenotypes (best aligned with `gen_pheno`) and
+            the corresponding reordering of the pathways
+        """
+        if Wk is None:
+            Wk = self.Wk
+
+        if tw == 1:
+            gen_pheno = torch.unsqueeze(gen_pheno, 2)  # transform into a matrix
+
+        if gen_pheno[0].shape != self.Ph[0].shape:
+            raise ValueError(
+                "The generated phenotypes and computed phenotypes do not have the same shape"
+            )
+
+        dic = np.zeros(
+            (gen_pheno.shape[0], self.Ph.shape[0])
+        )  # construct a cost matrix
+
+        Ph = self.Ph.flip(2)
+
+        for i in range(gen_pheno.shape[0]):
+            for j in range(Ph.shape[0]):
+                dic[i][j] = torch.norm((gen_pheno[i] - Ph[j]), p="fro").item()
+
+        m = Munkres()  # use the Hungarian algorithm to find phenotype correspondences
+        indexes = m.compute(dic)
+
+        # Reorder phenotypes
+        reordered_pheno = Ph.clone()
+        for row, column in indexes:
+            reordered_pheno[row] = Ph[column]
+
+        # Reorder pathways
+        reordered_pathways = Wk.clone()
+        for row, column in indexes:
+            reordered_pathways[:, row] = Wk[:, column]
+
+        return reordered_pheno, reordered_pathways
+
+
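`reorderPhenotypes` is a convenience for evaluation against known ground-truth phenotypes: it flips the learned phenotypes back into temporal order and matches them to `gen_pheno` with the Hungarian algorithm (via `munkres`). A sketch of the intended call, using a random stand-in for the ground truth:

    import torch

    # `module` is assumed to be a fastSWoTTeDModule already fitted with fastSWoTTeDTrainer;
    # gen_pheno is a random stand-in for ground-truth phenotypes with the same shape as module.Ph.
    gen_pheno = torch.rand(module.rank, module.N, module.twl)
    rPh, rW = module.reorderPhenotypes(gen_pheno, tw=module.twl)
    # rPh: module.Ph flipped to temporal order and permuted to best match gen_pheno
    # rW: module.Wk with its phenotype axis permuted accordingly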
+class fastSWoTTeDTrainer(pl.Trainer):
+    """Trainer for fast-SWoTTeD
+
+    This class redefines the lightning trainer to take into account the
+    specificities of the training procedure used for tensor decomposition.
+    """
+
+    def fit(
+        self,
+        model: fastSWoTTeDModule,
+        train_dataloaders,
+        val_dataloaders=None,
+        datamodule=None,
+        ckpt_path=None,
+    ):
+        model.Wk = Variable(
+            torch.rand(
+                len(train_dataloaders.dataset),
+                model.rank,
+                train_dataloaders.dataset[0][0].shape[1] - model.twl + 1,
+            ),
+            requires_grad=False,
+        )
+
+        # O is a matrix used for the non-succession constraint
+        model.O = torch.transpose(
+            torch.stack([torch.eye(model.Wk.shape[1])] * (2 * model.twl + 1)), 0, 2
+        )
+        ret = super().fit(
+            model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+        )
+
+        return ret
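Putting the pieces together, a minimal end-to-end sketch under the same assumptions as above (synthetic data, a list-collating DataLoader, and the `Loss` metric). The final `flip(2)` follows the module docstring, which warns that learned phenotypes must be flipped to read in temporal order:

    import torch
    from torch.utils.data import DataLoader
    from omegaconf import OmegaConf
    from swotted import fastSWoTTeDDataset, fastSWoTTeDModule, fastSWoTTeDTrainer

    # Same configuration fields as in the sketch next to fastSWoTTeDModule above.
    config = OmegaConf.create(
        {
            "model": {"rank": 3, "N": 10, "twl": 4, "metric": "Loss",
                      "sparsity": 0.1, "non_succession": 0.1},
            "training": {"lr": 1e-2},
            "predict": {"lr": 1e-2, "nepochs": 100},
        }
    )

    X = torch.rand(50, 10, 30)  # K=50 individuals, N=10 features, T=30 time steps
    loader = DataLoader(fastSWoTTeDDataset(X), batch_size=8, collate_fn=list)

    module = fastSWoTTeDModule(config)
    trainer = fastSWoTTeDTrainer(max_epochs=50)  # initialises module.Wk and module.O before fitting
    trainer.fit(module, loader)

    Ph = module.Ph.detach().flip(2)        # learned phenotypes, flipped into temporal order
    W_new = module(torch.rand(5, 10, 30))  # decomposition of unseen data: (5, rank, 30 - twl + 1)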