SWoTTeD 1.0.2a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- swotted/__init__.py +17 -0
- swotted/decomposition_contraints.py +87 -0
- swotted/fastswotted.py +604 -0
- swotted/loss_metrics.py +39 -0
- swotted/slidingWindow_model.py +161 -0
- swotted/swotted.py +417 -0
- swotted/temporal_regularization.py +56 -0
- swotted/utils.py +52 -0
- swotted/version.py +9 -0
- swotted-1.0.2a4.dist-info/METADATA +249 -0
- swotted-1.0.2a4.dist-info/RECORD +13 -0
- swotted-1.0.2a4.dist-info/WHEEL +4 -0
- swotted-1.0.2a4.dist-info/licenses/LICENSE.txt +165 -0
swotted/loss_metrics.py
ADDED
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+"""This module contains the alternative losses that can be used
+in tensor decomposition tasks.
+"""
+
+import torch
+
+
+class Loss:
+    """Difference loss"""
+
+    def compute(self, X, Y):
+        return (X - Y).sum()
+
+
+class Frobenius(Loss):
+    """Frobenius loss to be used with data assuming a Gaussian distribution
+    of their values."""
+
+    def compute(self, X, Y):
+        return torch.norm((X - Y), p="fro").sum()
+
+
+class Poisson(Loss):
+    """Poisson loss to be used with data assuming a Poisson distribution
+    of their values (counting attribute)."""
+
+    def compute(self, X, Y):
+        return Y.sum() - (X * torch.log(Y.clamp(min=1e-10))).sum()
+
+
+class Bernoulli(Loss):
+    """Bernoulli loss to be used with data assuming a Bernoulli distribution
+    of their values (binary values)."""
+
+    def compute(self, X, Y):
+        return (torch.log(1 + Y.clamp(min=1e-10))).sum() - (
+            X * torch.log(Y.clamp(min=1e-10))
+        ).sum()
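
A minimal usage sketch of these loss classes (not part of the package sources; the tensor shapes are arbitrary illustrations). Each metric compares a target slice X with a reconstruction Y of the same shape and returns a scalar tensor:

    import torch
    from swotted.loss_metrics import Frobenius, Poisson, Bernoulli

    X = torch.rand(10, 12)  # target slice (e.g. drugs x time)
    Y = torch.rand(10, 12)  # reconstruction of the same shape

    # each compute() returns a scalar torch.Tensor usable as a training loss
    print(Frobenius().compute(X, Y).item())
    print(Poisson().compute(X, Y).item())
    print(Bernoulli().compute(X, Y).item())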
swotted/slidingWindow_model.py
ADDED
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+"""Sliding Windows reconstruction module
+
+This module implements the SWoTTeD reconstruction of tensors based
+on the temporal convolution of the temporal phenotypes with a pathway.
+
+Example
+--------
+.. code-block:: python
+
+    from model.slidingWindow_model import SlidingWindow
+    from model.loss_metrics import Bernoulli
+    import torch
+
+    Ph = torch.rand((5, 10, 3))  # generation of 5 phenotypes with 10 features and 3 timestamps
+    Wp = torch.rand((5, 12))  # generation of a pathway describing the occurrences of the 5 phenotypes across time
+
+    sw = SlidingWindow()
+    sw.setMetric(Bernoulli())
+
+    Yp = sw.reconstruct(Wp, Ph)
+
+"""
+from functools import reduce
+
+import torch
+import torch.nn as nn
+
+from swotted.loss_metrics import *
+
+
+class SlidingWindow(nn.Module):
+    """Torch module for the computation of the reconstruction error
+    by sliding phenotypes.
+    """
+
+    def setMetric(self, dist=Loss()):
+        """
+        Define the loss used to evaluate the tensor reconstruction.
+
+        Parameters
+        -----------
+        dist: Loss
+            one of the loss metrics available in the loss_metrics module.
+        """
+        self.metric = dist
+
+    def reconstruct(self, Wp, Ph):
+        """
+        Implementation of the SWoTTeD reconstruction scheme (convolutional reconstruction).
+
+        Notes
+        -----
+        The function does not ensure that the output values belong to [0,1].
+
+
+        Parameters
+        ----------
+        Ph: torch.Tensor
+            Phenotypes of size :math:`R * N * \\omega`, where :math:`R` is the
+            number of phenotypes and :math:`\\omega` the length of the temporal window
+        Wp: torch.Tensor
+            Assignment tensor of size :math:`R * (Tp-\\omega+1)` for patient :math:`p`
+
+        Returns
+        -------
+        torch.Tensor
+            the **SWoTTeD** reconstruction of a pathway from :math:`Wp` and :math:`Ph`.
+        """
+        # create a tensor of windows
+        Yp = torch.conv1d(
+            Wp.squeeze(dim=0), Ph.transpose(0, 1).flip(2), padding=Ph.shape[2] - 1
+        )
+        return Yp
+
+    def loss(self, Xp, Wp, Ph, padding=None):
+        """Evaluation of the SWoTTeD reconstruction loss (see reconstruct method).
+
+        Parameters
+        -----------
+        Xp: torch.Tensor
+            A 2nd-order tensor of size :math:`N * Tp`, where :math:`N` is the number
+            of drugs and :math:`Tp` is the length of the patient's stay
+        Ph: torch.Tensor
+            Phenotypes of size :math:`R * N * \\omega`, where :math:`R` is the
+            number of phenotypes and :math:`\\omega` the length of the temporal window
+        Wp: torch.Tensor
+            Assignment tensor of size :math:`R * Tp` for patient :math:`p`
+        padding: None, bool or tuple
+            If `padding` is True then the loss is evaluated on the interval
+            :math:`[\\omega, L-\\omega]` of the pathway.
+            If `padding` is a tuple `(a,b)`, then the loss is evaluated on the
+            interval :math:`[a, L-b]`.
+            Default is None (no padding)
+
+        Returns
+        -------
+        float
+            the SWoTTeD reconstruction loss of one patient.
+        """
+        Yp = self.reconstruct(Wp, Ph)
+        Twindow = Ph.shape[2]
+
+        if padding is not None:
+            if isinstance(padding, bool) and padding:
+                Yp = torch.split(
+                    Yp,
+                    [Twindow - 1, Yp.shape[1] - 2 * (Twindow - 1), Twindow - 1],
+                    dim=1,
+                )[1]
+                Xp = torch.split(
+                    Xp,
+                    [Twindow - 1, Xp.shape[1] - 2 * (Twindow - 1), Twindow - 1],
+                    dim=1,
+                )[1]
+            elif isinstance(padding, tuple) and len(padding) == 2:
+                Yp = torch.split(
+                    Yp,
+                    [padding[0], Yp.shape[1] - padding[0] - padding[1], padding[1]],
+                    dim=1,
+                )[1]
+                Xp = torch.split(
+                    Xp,
+                    [padding[0], Xp.shape[1] - padding[0] - padding[1], padding[1]],
+                    dim=1,
+                )[1]
+
+        # evaluate the loss
+        return self.metric.compute(Xp, Yp)
+
+    def forward(self, X, W, Ph, padding=None):
+        """Evaluation of the SWoTTeD reconstruction loss for a collection of patients
+        (see reconstruct method).
+
+        Parameters
+        ----------
+        X: list[torch.Tensor]
+            A 3rd-order tensor of size :math:`K * N * Tp`, where :math:`K` is the number
+            of patients, :math:`N` is the number of drugs and :math:`Tp` is the length of
+            the patient's stay
+        Ph: torch.Tensor
+            Phenotypes of size :math:`R * N * \\omega`, where :math:`R` is the
+            number of phenotypes and :math:`\\omega` the length of the temporal window
+        W: list[torch.Tensor]
+            Assignment tensor of size :math:`K * R * Tp` for patient :math:`p`
+        padding: None, bool or tuple
+            If `padding` is True then the loss is evaluated on the interval
+            :math:`[\\omega, L-\\omega]` of the pathway.
+            If `padding` is a tuple `(a,b)`, then the loss is evaluated on the interval
+            :math:`[a, L-b]`.
+            Default is `None` (no padding)
+
+        Returns
+        -------
+        float
+            The SWoTTeD reconstruction loss of a collection of patients, that is the sum of
+            the losses for all patients.
+        """
+        return reduce(
+            torch.add, [self.loss(Xp, Wp, Ph, padding) for Xp, Wp in zip(X, W)]
+        )
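
As an informal shape check (not part of the package sources; dimensions below are illustrative), the convolutional reconstruction maps an assignment matrix of size R x (Tp - ω + 1) back to an N x Tp matrix, because conv1d with padding ω - 1 restores the full time axis. The sketch assumes a PyTorch version that accepts the unbatched (channels, length) input used by the module's own docstring example:

    import torch
    from swotted.slidingWindow_model import SlidingWindow
    from swotted.loss_metrics import Frobenius

    R, N, omega, Tp = 4, 10, 3, 20
    Ph = torch.rand(R, N, omega)        # phenotypes
    Wp = torch.rand(R, Tp - omega + 1)  # assignment over the sliding windows

    sw = SlidingWindow()
    sw.setMetric(Frobenius())

    Yp = sw.reconstruct(Wp, Ph)
    # (Tp - omega + 1) + 2*(omega - 1) - omega + 1 == Tp
    assert Yp.shape == (N, Tp)

    Xp = torch.rand(N, Tp)                    # observed patient matrix
    print(sw.loss(Xp, Wp, Ph, padding=True))  # scalar reconstruction loss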
swotted/swotted.py
ADDED
@@ -0,0 +1,417 @@
+# -*- coding: utf-8 -*-
+"""The SWoTTeD module
+"""
+
+import numpy as np
+import torch
+import lightning.pytorch as pl
+import torch.optim as optim
+from torch.autograd import Variable
+from munkres import Munkres
+
+from omegaconf import DictConfig
+
+
+from swotted.slidingWindow_model import SlidingWindow
+from swotted.loss_metrics import *
+from swotted.decomposition_contraints import *
+from swotted.temporal_regularization import *
+from swotted.utils import *
+
+
+class swottedModule(pl.LightningModule):
+    """SWoTTeD module (Lightning module)"""
+
+    def __init__(self, config: DictConfig):
+        super().__init__()
+
+        # use config as parameter
+        self.params = config
+
+        self.model = SlidingWindow()
+        self.model.setMetric(eval(self.params.model.metric)())
+
+        self.alpha = self.params.model.sparsity  # sparsity
+        self.beta = self.params.model.non_succession  # non-succession
+        self.adam = True
+
+        self.sparsity = self.params.model.sparsity > 0
+        self.pheno_succession = self.params.model.non_succession > 0
+        self.non_negativity = True
+        self.normalization = True
+
+        self.rank = self.params.model.rank
+        self.N = self.params.model.N
+        self.twl = self.params.model.twl
+
+        self.Ph = torch.nn.Parameter(
+            torch.rand(
+                (self.params.model.rank, self.params.model.N, self.params.model.twl)
+            )
+        )
+
+        # Important: Wk is not directly part of the model
+        self.Wk = None
+
+        # Important: this property activates manual optimization.
+        self.automatic_optimization = False
+
+    def configure_optimizers(self):
+        """
+        Parent override.
+        """
+
+        if self.adam:
+            optimizerPh = optim.Adam([self.Ph], lr=self.params.training.lr)
+        else:
+            optimizerPh = optim.SGD([self.Ph], lr=self.params.training.lr, momentum=0.9)
+
+        if self.adam:
+            optimizerW = optim.Adam(self.Wk, lr=self.params.training.lr)
+        else:
+            optimizerW = optim.SGD(self.Wk, lr=self.params.training.lr, momentum=0.9)
+
+        return optimizerPh, optimizerW
+
+    def forward(self, X):
+        """
+        This forward function computes the decomposition of the tensor `X`.
+        It contains an optimisation stage to find the best decomposition.
+        The optimisation does not modify the phenotypes of the model.
+
+        Parameters
+        -----------
+        X: torch.Tensor
+            tensor of dimension :math:`K * N * T` to decompose according to
+            the phenotypes of the model
+
+        Returns
+        --------
+        torch.Tensor
+            A tensor of dimension :math:`K * R * (T-Tw+1)` that is the decomposition
+            of X according to the :math:`R` phenotypes of the model
+        """
+        # self.unfreeze()
+        K = len(X)  # number of patients
+        if self.N != X[0].shape[0]:  # number of medical events
+            # TODO throw an error
+            return None
+
+        with torch.inference_mode(False):
+            # Lightning activates inference mode, which deeply disables gradient
+            # computation in this function; calling enable_grad() alone is not sufficient.
+
+            Wk_batch = [
+                Variable(
+                    torch.rand(self.rank, X[Tp].shape[1] - self.twl + 1),
+                    requires_grad=True,
+                )
+                for Tp in range(K)
+            ]
+            optimizerW = optim.Adam(Wk_batch, lr=self.params["predict"]["lr"])
+
+            n_epochs = self.params["predict"]["nepochs"]
+            for _ in range(n_epochs):
+
+                def closure():
+                    optimizerW.zero_grad()
+                    loss = self.model(X, Wk_batch, self.Ph.data)
+                    if self.pheno_succession:
+                        loss += self.beta * phenotypeSuccession_constraint(
+                            Wk_batch, self.twl
+                        )
+                    loss.backward()
+                    return loss
+
+                optimizerW.step(closure)
+                if self.non_negativity:
+                    nonnegative_projection(*Wk_batch)
+                if self.normalization:
+                    normalization_constraint(*Wk_batch)
+        # self.freeze()
+        return Wk_batch
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        """
+        Parent override.
+        """
+        return self(batch)  # it only calls the forward function
+
+    def training_step(self, batch, idx):
+        """
+        Parent override.
+        """
+
+        optimizerPh, optimizerW = self.optimizers()
+
+        D, indices = zip(*batch)
+        X = D
+        Wk_batch = [self.Wk[p] for p in indices]
+        Wk_batch_nograd = [self.Wk[p].data for p in indices]
+
+        def closure():
+            optimizerPh.zero_grad()
+            loss = self.model(X, Wk_batch_nograd, self.Ph)
+            self.log(
+                "train_reconstr_Ph",
+                loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+            if self.sparsity:
+                sparsity_loss = sparsity_constraint(self.Ph)
+                self.log(
+                    "train_sparsity_Ph",
+                    sparsity_loss,
+                    on_step=True,
+                    on_epoch=False,
+                    prog_bar=False,
+                    logger=True,
+                )
+                loss += self.alpha * sparsity_loss
+            loss.backward()
+            self.log(
+                "train_loss_Ph",
+                loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                batch_size=len(indices),
+            )
+            return loss
+
+        optimizerPh.step(closure)
+
+        if self.non_negativity:
+            nonnegative_projection(*self.Ph)  # non-negativity constraint
+        if self.normalization:
+            normalization_constraint(*self.Ph)  # normalization constraint
+
+        # update W
+        def closure():
+            optimizerW.zero_grad()
+            loss = self.model(X, Wk_batch, self.Ph.data)
+            self.log(
+                "train_reconstr_W",
+                loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+            if self.pheno_succession:
+                nonsucc_loss = phenotypeSuccession_constraint(Wk_batch, self.twl)
+                self.log(
+                    "train_nonsucc_W",
+                    nonsucc_loss,
+                    on_step=True,
+                    on_epoch=False,
+                    prog_bar=False,
+                    logger=True,
+                )
+                loss += self.beta * nonsucc_loss
+            loss.backward()
+            self.log(
+                "train_loss_W",
+                loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                batch_size=len(indices),
+            )
+            return loss
+
+        optimizerW.step(closure)
+        if self.non_negativity:
+            nonnegative_projection(*Wk_batch)
+        if self.normalization:
+            normalization_constraint(*Wk_batch)
+
+    def test_step(self, batch, batch_idx):
+        """test step"""
+        X, _ = zip(*batch)
+        W_hat = self(X)
+        loss = self.model(X, W_hat, self.Ph)
+        self.log("test_loss", loss)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        """
+        Parent override.
+
+        ***This function has not been tested***
+        """
+        X, y = zip(*batch)
+        W_hat = self(
+            X
+        )  # apply the model on the data (requires optimisation of local W)
+        loss = self.model(X, W_hat, self.Ph)
+        # self.log("val_loss", loss)
+        return loss
+
+    def forecast(self, X):
+        """
+        This function forecasts the next time step using the trained phenotypes.
+        It can be used only with the parameter :math:`\\omega \\geq 2` (`twl>=2`),
+        i.e. phenotypes spanning more than one time step.
+
+        This function makes a projection of the data with the phenotypes of the model.
+
+        For computational efficiency, the data are reduced to their last :math:`\\omega-1`
+        time steps, and then extended with :math:`\\omega` empty time steps on the
+        right (the region to predict).
+
+        Parameters
+        ----------
+        X: torch.Tensor
+            tensor of dimension :math:`K * N * T` to decompose according
+            to the phenotypes of the model.
+
+        Returns
+        --------
+        torch.Tensor
+            A tensor of dimension :math:`K * N` that is the forecast of the
+            next time step of :math:`X`.
+        """
+
+        if self.twl < 2:
+            # trained with daily phenotypes
+            # TODO throw an error
+            return None
+
+        K = len(X)  # number of patients
+        if self.N != X[0].shape[0]:  # number of medical events
+            # TODO throw an error
+            return None
+
+        # keep the last (twl - 1) time steps of each patient and append twl empty
+        # time steps (the region to predict)
+        X = [
+            torch.cat(
+                (xi[:, -(self.twl - 1) :], torch.zeros((self.N, self.twl))), axis=1
+            )
+            for xi in X
+        ]
+
+        # now, we decompose the tensor ... without considering the last part of the
+        # reconstruction, i.e. the predicted part
+        with torch.inference_mode(False):
+            # Lightning activates inference mode, which deeply disables gradient
+            # computation in this function; calling enable_grad() alone is not sufficient.
+
+            Wk_batch = [
+                Variable(
+                    torch.rand(self.rank, X[Tp].shape[1] - self.twl + 1),
+                    requires_grad=True,
+                )
+                for Tp in range(K)
+            ]
+            optimizerW = optim.Adam(Wk_batch, lr=self.params["predict"]["lr"])
+
+            n_epochs = self.params["predict"]["nepochs"]
+            for _ in range(n_epochs):
+
+                def closure():
+                    optimizerW.zero_grad()
+                    # evaluate the loss based on the beginning of the reconstruction only
+                    loss = self.model(X, Wk_batch, self.Ph.data, padding=(0, self.twl))
+                    if self.pheno_succession:
+                        loss += self.beta * phenotypeSuccession_constraint(
+                            Wk_batch, self.twl
+                        )
+                    loss.backward()
+                    return loss
+
+                optimizerW.step(closure)
+                if self.non_negativity:
+                    nonnegative_projection(*Wk_batch)
+                if self.normalization:
+                    normalization_constraint(*Wk_batch)
+
+        # make a reconstruction, and select only the next event
+        with torch.no_grad():
+            pred = [
+                self.model.reconstruct(x, self.Ph.data)[:, self.twl] for x in Wk_batch
+            ]
+        return pred
+
+    def reorderPhenotypes(self, gen_pheno, Wk=None, tw=2):
+        """
+        This function outputs reordered internal phenotypes and pathways.
+
+        Parameters
+        ----------
+        gen_pheno: torch.Tensor
+            generated phenotypes of size :math:`R * N * Tw`, where :math:`R` is the number of
+            phenotypes, :math:`N` is the number of drugs and :math:`Tw` is the length of the
+            temporal window
+        Wk: torch.Tensor
+            pathways to reorder; if None, the internal pathways are used
+        tw: int
+            window size
+
+        Returns
+        -------
+        A pair :math:`(rPh,rW)` with reordered phenotypes (aligned at best with gen_pheno) and the
+        corresponding reordering of the pathways
+        """
+        if Wk is None:
+            Wk = self.Wk
+
+        if tw == 1:
+            gen_pheno = torch.unsqueeze(gen_pheno, 2)  # add a singleton temporal dimension
+
+        if gen_pheno[0].shape != self.Ph[0].shape:
+            raise ValueError(
+                f"The generated phenotypes ({gen_pheno[0].shape}) and computed phenotypes \
+                ({self.Ph[0].shape}) do not have the same shape."
+            )
+
+        dic = np.zeros(
+            (gen_pheno.shape[0], self.Ph.shape[0])
+        )  # construct a cost matrix
+
+        for i in range(gen_pheno.shape[0]):
+            for j in range(self.Ph.shape[0]):
+                dic[i][j] = torch.norm((gen_pheno[i] - self.Ph[j]), p="fro").item()
+
+        m = Munkres()  # use the Hungarian algorithm to find phenotype correspondences
+        indexes = m.compute(dic)
+
+        # Reorder phenotypes
+        reordered_pheno = self.Ph.clone()
+        for row, column in indexes:
+            reordered_pheno[row] = self.Ph[column]
+
+        # Reorder pathways
+        reordered_pathways = [Wk[i].clone() for i in range(len(Wk))]
+        for i in range(len(Wk)):
+            for row, column in indexes:
+                reordered_pathways[i][row] = Wk[i][column]
+
+        return reordered_pheno, reordered_pathways
+
+
+class swottedTrainer(pl.Trainer):
+    def fit(
+        self,
+        model: swottedModule,
+        train_dataloaders,
+        val_dataloaders=None,
+        datamodule=None,
+        ckpt_path=None,
+    ):
+        model.Wk = [
+            Variable(
+                torch.rand(model.rank, ds[0].shape[1] - model.twl + 1),
+                requires_grad=True,
+            )
+            for ds in train_dataloaders.dataset
+        ]
+        return super().fit(
+            model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+        )
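
A hedged end-to-end sketch of how swottedModule can be instantiated and applied to a list of patient matrices (not taken from the package documentation). The configuration keys below are inferred from the attributes the code above reads in __init__ (model.*), configure_optimizers (training.*) and forward (predict.*); the values are illustrative, and sparsity and non_succession are set to 0 so that only the code paths shown above are exercised:

    import torch
    from omegaconf import OmegaConf
    from swotted.swotted import swottedModule

    # hypothetical configuration, inferred from the fields accessed by the module
    config = OmegaConf.create({
        "model": {"metric": "Bernoulli", "rank": 4, "N": 10, "twl": 3,
                  "sparsity": 0.0, "non_succession": 0.0},
        "training": {"lr": 1e-2},
        "predict": {"lr": 1e-2, "nepochs": 50},
    })

    swotted = swottedModule(config)

    # decompose a small collection of patients (list of N x Tp binary matrices)
    X = [torch.randint(0, 2, (10, 15)).float() for _ in range(5)]
    W = swotted(X)      # one R x (Tp - twl + 1) assignment matrix per patient
    print(W[0].shape)   # expected: torch.Size([4, 13])

Training the phenotypes themselves goes through swottedTrainer.fit, which, as shown above, initializes one assignment matrix per training pathway before delegating to Lightning's Trainer.fit.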
swotted/temporal_regularization.py
ADDED
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+"""Temporal regularization module
+"""
+import torch
+from torch import nn
+
+
+class TemporalDependency(nn.Module):
+    """Torch Module to implement the temporal regularization losses.
+    This module is based on an LSTM.
+    """
+
+    def __init__(self, rank, nlayers, nhidden, dropout):
+        super(TemporalDependency, self).__init__()
+
+        self.nlayers = nlayers
+        self.nhid = nhidden
+
+        self.rnn = nn.LSTM(
+            input_size=rank,
+            hidden_size=nhidden,
+            num_layers=nlayers,
+            dropout=dropout,
+            batch_first=True,
+        )
+        self.decoder = nn.Sequential(nn.Linear(nhidden, rank), nn.ReLU())
+        self.init_weights()
+
+    def init_weights(self):
+        init_range = 0.1
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                m.weight.data.uniform_(-init_range, init_range)
+                m.bias.data.zero_()
+
+    def forward(self, Ws, device):
+        train_loss = 0.0
+        for Wp in Ws:
+            inputs, targets = Wp[:-1, :], Wp[1:, :]  # seq_len x n_dim
+            seq_len, n_dims = inputs.size()
+
+            hidden = self.init_hidden(1)
+            # seq_len x n_dims --> 1 x seq_len x n_dims
+            outputs, _ = self.rnn(inputs.unsqueeze(0), hidden)
+            logits = self.decoder(outputs.contiguous().view(-1, self.nhid))
+            loss = self.loss(logits, targets)
+            train_loss += loss
+        return train_loss
+
+    def init_hidden(self, batch_sz):
+        size = (self.nlayers, batch_sz, self.nhid)
+        weight = next(self.parameters())
+        return (weight.new_zeros(*size), weight.new_zeros(*size))
+
+    def loss(self, input, target):
+        return torch.mean((input - target) ** 2)
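
A small usage sketch of this regularizer (not part of the package sources; shapes are illustrative). As read from forward(), each pathway is expected as a (seq_len, rank) sequence, i.e. the transpose of the R x T' assignment matrices used elsewhere in the package, and the module returns the summed one-step-ahead MSE over all pathways:

    import torch
    from swotted.temporal_regularization import TemporalDependency

    R = 4  # decomposition rank
    reg = TemporalDependency(rank=R, nlayers=1, nhidden=8, dropout=0.0)

    # two pathways of different lengths, each as a (seq_len, rank) sequence
    Ws = [torch.rand(12, R), torch.rand(9, R)]

    penalty = reg(Ws, device="cpu")  # summed next-step prediction error
    print(penalty.item())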