SWoTTeD 1.0.2a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
swotted/loss_metrics.py ADDED
@@ -0,0 +1,39 @@
+ # -*- coding: utf-8 -*-
+ """This module contains the alternative losses that can be used
+ in tensor decomposition tasks.
+ """
+
+ import torch
+
+
+ class Loss:
+     """Difference loss"""
+
+     def compute(self, X, Y):
+         return (X - Y).sum()
+
+
+ class Frobenius(Loss):
+     """Frobenius loss to be used with data assuming a Gaussian distribution
+     of their values."""
+
+     def compute(self, X, Y):
+         return torch.norm((X - Y), p="fro").sum()
+
+
+ class Poisson(Loss):
+     """Poisson loss to be used with data assuming a Poisson distribution
+     of their values (count data)."""
+
+     def compute(self, X, Y):
+         return Y.sum() - (X * torch.log(Y.clamp(min=1e-10))).sum()
+
+
+ class Bernoulli(Loss):
+     """Bernoulli loss to be used with data assuming a Bernoulli distribution
+     of their values (binary values)."""
+
+     def compute(self, X, Y):
+         return (torch.log(1 + Y.clamp(min=1e-10))).sum() - (
+             X * torch.log(Y.clamp(min=1e-10))
+         ).sum()
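All four classes expose the same `compute(X, Y)` interface, where `X` is the observed tensor and `Y` its reconstruction, so they can be swapped freely when configuring a model. Below is a minimal sketch (arbitrary illustrative shapes and values, not part of the package) of how the losses might be compared on the same data:

    import torch
    from swotted.loss_metrics import Frobenius, Poisson, Bernoulli

    X = torch.randint(0, 2, (10, 7)).float()  # observed binary matrix (e.g. drugs x time)
    Y = torch.rand((10, 7))                   # candidate reconstruction

    # each loss returns a scalar tensor; lower means a better reconstruction
    for loss in (Frobenius(), Poisson(), Bernoulli()):
        print(type(loss).__name__, loss.compute(X, Y).item())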
swotted/slidingWindow_model.py ADDED
@@ -0,0 +1,161 @@
+ # -*- coding: utf-8 -*-
+ """Sliding Windows reconstruction module
+
+ This module implements the SWoTTeD reconstruction of tensors based
+ on the temporal convolution of the temporal phenotypes with a pathway.
+
+ Example
+ --------
+ .. code-block:: python
+
+     from swotted.slidingWindow_model import SlidingWindow
+     from swotted.loss_metrics import Bernoulli
+     import torch
+
+     Ph = torch.rand((5, 10, 3))  # generation of 5 phenotypes with 10 features and 3 timestamps
+     Wp = torch.rand((5, 12))  # generation of a pathway describing the occurrences of the 5 phenotypes across time
+
+     sw = SlidingWindow()
+     sw.setMetric(Bernoulli())
+
+     Yp = sw.reconstruct(Wp, Ph)
+
+ """
+ from functools import reduce
+
+ import torch
+ import torch.nn as nn
+
+ from swotted.loss_metrics import *
+
+
+ class SlidingWindow(nn.Module):
+     """Torch module for the computation of the reconstruction error
+     by sliding phenotypes.
+     """
+
+     def setMetric(self, dist=Loss()):
+         """
+         Define the loss used to evaluate the tensor reconstruction.
+
+         Parameters
+         ----------
+         dist: Loss
+             one of the loss metrics available in the loss_metrics module.
+         """
+         self.metric = dist
+
+     def reconstruct(self, Wp, Ph):
+         """
+         Implementation of the SWoTTeD reconstruction scheme (convolutional reconstruction).
+
+         Notes
+         -----
+         The function does not ensure that the output values belong to [0,1]
+
+
+         Parameters
+         ----------
+         Ph: torch.Tensor
+             Phenotypes of size :math:`R * N * \\omega`, where :math:`R` is the
+             number of phenotypes and :math:`\\omega` the length of the temporal window
+         Wp: torch.Tensor
+             Assignment tensor of size :math:`R * (Tp-\\omega+1)` for patient :math:`p`
+
+         Returns
+         -------
+         torch.Tensor
+             the **SWoTTeD** reconstruction of a pathway from :math:`Wp` and :math:`Ph`.
+         """
+         # convolve the time-reversed phenotypes over the pathway (full padding)
+         Yp = torch.conv1d(
+             Wp.squeeze(dim=0), Ph.transpose(0, 1).flip(2), padding=Ph.shape[2] - 1
+         )
+         return Yp
+
+     def loss(self, Xp, Wp, Ph, padding=None):
+         """Evaluation of the SWoTTeD reconstruction loss (see reconstruct method).
+
+         Parameters
+         ----------
+         Xp: torch.Tensor
+             A 2nd-order tensor of size :math:`N * Tp`, where :math:`N` is the number
+             of drugs and :math:`Tp` is the duration of the patient's stay
+         Ph: torch.Tensor
+             Phenotypes of size :math:`R * N * \\omega`, where :math:`R` is the
+             number of phenotypes and :math:`\\omega` the length of the temporal window
+         Wp: torch.Tensor
+             Assignment tensor of size :math:`R * (Tp-\\omega+1)` for patient :math:`p`
+         padding: None, bool or tuple
+             If `padding` is True then the loss is evaluated on the interval
+             :math:`[\\omega, L-\\omega]` of the pathway.
+             If `padding` is a tuple `(a,b)`, then the loss is evaluated on the
+             interval :math:`[a, L-b]`.
+             Default is None (no padding)
+
+         Returns
+         -------
+         float
+             the SWoTTeD reconstruction loss of one patient.
+         """
+         Yp = self.reconstruct(Wp, Ph)
+         Twindow = Ph.shape[2]
+
+         if padding is not None:
+             if isinstance(padding, bool) and padding:
+                 Yp = torch.split(
+                     Yp,
+                     [Twindow - 1, Yp.shape[1] - 2 * (Twindow - 1), Twindow - 1],
+                     dim=1,
+                 )[1]
+                 Xp = torch.split(
+                     Xp,
+                     [Twindow - 1, Xp.shape[1] - 2 * (Twindow - 1), Twindow - 1],
+                     dim=1,
+                 )[1]
+             elif isinstance(padding, tuple) and len(padding) == 2:
+                 Yp = torch.split(
+                     Yp,
+                     [padding[0], Yp.shape[1] - padding[0] - padding[1], padding[1]],
+                     dim=1,
+                 )[1]
+                 Xp = torch.split(
+                     Xp,
+                     [padding[0], Xp.shape[1] - padding[0] - padding[1], padding[1]],
+                     dim=1,
+                 )[1]
+
+         # evaluate the loss
+         return self.metric.compute(Xp, Yp)
+
+     def forward(self, X, W, Ph, padding=None):
+         """Evaluation of the SWoTTeD reconstruction loss for a collection of patients
+         (see reconstruct method).
+
+         Parameters
+         ----------
+         X: list[torch.Tensor]
+             A collection of :math:`K` 2nd-order tensors of size :math:`N * Tp`, where :math:`K` is the number
+             of patients, :math:`N` is the number of drugs and :math:`Tp` is the duration of the
+             patient's stay
+         Ph: torch.Tensor
+             Phenotypes of size :math:`R * N * \\omega`, where :math:`R` is the
+             number of phenotypes and :math:`\\omega` the length of the temporal window
+         W: list[torch.Tensor]
+             A collection of :math:`K` assignment tensors of size :math:`R * (Tp-\\omega+1)`, one per patient
+         padding: None, bool or tuple
+             If `padding` is True then the loss is evaluated on the interval
+             :math:`[\\omega, L-\\omega]` of the pathway.
+             If `padding` is a tuple `(a,b)`, then the loss is evaluated on the interval
+             :math:`[a, L-b]`.
+             Default is `None` (no padding)
+
+         Returns
+         -------
+         float
+             The SWoTTeD reconstruction loss of a collection of patients, that is the sum of
+             the losses for all patients.
+         """
+         return reduce(
+             torch.add, [self.loss(Xp, Wp, Ph, padding) for Xp, Wp in zip(X, W)]
+         )
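In addition to the reconstruction example in the module docstring, the `padding` options of `loss` can be exercised directly. The following is a minimal sketch with arbitrary shapes (N=10 features, Tp=12 time steps, R=5 phenotypes, omega=3), not taken from the package; the `padding=(0, 3)` call mirrors the way `forecast` masks the predicted region:

    import torch
    from swotted.slidingWindow_model import SlidingWindow
    from swotted.loss_metrics import Frobenius

    sw = SlidingWindow()
    sw.setMetric(Frobenius())

    Xp = torch.rand((10, 12))    # one patient: N x Tp
    Ph = torch.rand((5, 10, 3))  # R x N x omega
    Wp = torch.rand((5, 10))     # R x (Tp - omega + 1)

    full = sw.loss(Xp, Wp, Ph)                  # loss over the whole stay
    inner = sw.loss(Xp, Wp, Ph, padding=True)   # drop omega-1 time steps at each end
    head = sw.loss(Xp, Wp, Ph, padding=(0, 3))  # drop the last 3 time steps only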
swotted/swotted.py ADDED
@@ -0,0 +1,417 @@
+ # -*- coding: utf-8 -*-
+ """The SWoTTeD module
+ """
+
+ import numpy as np
+ import torch
+ import lightning.pytorch as pl
+ import torch.optim as optim
+ from torch.autograd import Variable
+ from munkres import Munkres
+
+ from omegaconf import DictConfig
+
+
+ from swotted.slidingWindow_model import SlidingWindow
+ from swotted.loss_metrics import *
+ from swotted.decomposition_contraints import *
+ from swotted.temporal_regularization import *
+ from swotted.utils import *
+
+
+ class swottedModule(pl.LightningModule):
+     """SWoTTeD module (Lightning module)"""
+
+     def __init__(self, config: DictConfig):
+         super().__init__()
+
+         # use config as parameter
+         self.params = config
+
+         self.model = SlidingWindow()
+         self.model.setMetric(eval(self.params.model.metric)())
+
+         self.alpha = self.params.model.sparsity  # sparsity
+         self.beta = self.params.model.non_succession  # non-succession
+         self.adam = True
+
+         self.sparsity = self.params.model.sparsity > 0
+         self.pheno_succession = self.params.model.non_succession > 0
+         self.non_negativity = True
+         self.normalization = True
+
+         self.rank = self.params.model.rank
+         self.N = self.params.model.N
+         self.twl = self.params.model.twl
+
+         self.Ph = torch.nn.Parameter(
+             torch.rand(
+                 (self.params.model.rank, self.params.model.N, self.params.model.twl)
+             )
+         )
+
+         # Important: Wk is not directly part of the model
+         self.Wk = None
+
+         # Important: This property activates manual optimization.
+         self.automatic_optimization = False
+
+     def configure_optimizers(self):
+         """
+         Parent override.
+         """
+
+         if self.adam:
+             optimizerPh = optim.Adam([self.Ph], lr=self.params.training.lr)
+         else:
+             optimizerPh = optim.SGD([self.Ph], lr=self.params.training.lr, momentum=0.9)
+
+         if self.adam:
+             optimizerW = optim.Adam(self.Wk, lr=self.params.training.lr)
+         else:
+             optimizerW = optim.SGD(self.Wk, lr=self.params.training.lr, momentum=0.9)
+
+         return optimizerPh, optimizerW
+
+     def forward(self, X):
+         """
+         This forward function computes the decomposition of the tensor `X`.
+         It contains an optimisation stage to find the best decomposition.
+         The optimisation does not modify the phenotypes of the model.
+
+         Parameters
+         ----------
+         X: torch.Tensor
+             tensor of dimension :math:`K * N * T` to decompose according to
+             the phenotypes of the model
+
+         Returns
+         -------
+         list[torch.Tensor]
+             A collection of :math:`K` tensors of dimension :math:`R * (T-\\omega+1)` that is the
+             decomposition of X according to the :math:`R` phenotypes of the model
+         """
+         # self.unfreeze()
+         K = len(X)  # number of patients
+         if self.N != X[0].shape[0]:  # number of medical events
+             # TODO throw an error
+             return None
+
+         with torch.inference_mode(False):
+             # PyTorch Lightning enables inference mode, which disables gradient computation
+             # more deeply than no_grad(); calling enable_grad() alone is not sufficient.
+
+             Wk_batch = [
+                 Variable(
+                     torch.rand(self.rank, X[Tp].shape[1] - self.twl + 1),
+                     requires_grad=True,
+                 )
+                 for Tp in range(K)
+             ]
+             optimizerW = optim.Adam(Wk_batch, lr=self.params["predict"]["lr"])
+
+             n_epochs = self.params["predict"]["nepochs"]
+             for _ in range(n_epochs):
+
+                 def closure():
+                     optimizerW.zero_grad()
+                     loss = self.model(X, Wk_batch, self.Ph.data)
+                     if self.pheno_succession:
+                         loss += self.beta * phenotypeSuccession_constraint(
+                             Wk_batch, self.twl
+                         )
+                     loss.backward()
+                     return loss
+
+                 optimizerW.step(closure)
+                 if self.non_negativity:
+                     nonnegative_projection(*Wk_batch)
+                 if self.normalization:
+                     normalization_constraint(*Wk_batch)
+         # self.freeze()
+         return Wk_batch
+
+     def predict_step(self, batch, batch_idx, dataloader_idx=0):
+         """
+         Parent override.
+         """
+         return self(batch)  # it only calls the forward function
+
+     def training_step(self, batch, idx):
+         """
+         Parent override.
+         """
+
+         optimizerPh, optimizerW = self.optimizers()
+
+         D, indices = zip(*batch)
+         X = D
+         Wk_batch = [self.Wk[p] for p in indices]
+         Wk_batch_nograd = [self.Wk[p].data for p in indices]
+
+         def closure():
+             optimizerPh.zero_grad()
+             loss = self.model(X, Wk_batch_nograd, self.Ph)
+             self.log(
+                 "train_reconstr_Ph",
+                 loss,
+                 on_step=True,
+                 on_epoch=False,
+                 prog_bar=False,
+                 logger=True,
+             )
+             if self.sparsity:
+                 sparsity_loss = sparsity_constraint(self.Ph)
+                 self.log(
+                     "train_sparsity_Ph",
+                     sparsity_loss,
+                     on_step=True,
+                     on_epoch=False,
+                     prog_bar=False,
+                     logger=True,
+                 )
+                 loss += self.alpha * sparsity_loss
+             loss.backward()
+             self.log(
+                 "train_loss_Ph",
+                 loss,
+                 on_step=True,
+                 on_epoch=False,
+                 prog_bar=False,
+                 logger=True,
+                 batch_size=len(indices),
+             )
+             return loss
+
+         optimizerPh.step(closure)
+
+         if self.non_negativity:
+             nonnegative_projection(*self.Ph)  # non-negativity constraint
+         if self.normalization:
+             normalization_constraint(*self.Ph)  # normalization constraint
+
+         # update W
+         def closure():
+             optimizerW.zero_grad()
+             loss = self.model(X, Wk_batch, self.Ph.data)
+             self.log(
+                 "train_reconstr_W",
+                 loss,
+                 on_step=True,
+                 on_epoch=False,
+                 prog_bar=False,
+                 logger=True,
+             )
+             if self.pheno_succession:
+                 nonsucc_loss = phenotypeSuccession_constraint(Wk_batch, self.twl)
+                 self.log(
+                     "train_nonsucc_W",
+                     nonsucc_loss,
+                     on_step=True,
+                     on_epoch=False,
+                     prog_bar=False,
+                     logger=True,
+                 )
+                 loss += self.beta * nonsucc_loss
+             loss.backward()
+             self.log(
+                 "train_loss_W",
+                 loss,
+                 on_step=True,
+                 on_epoch=False,
+                 prog_bar=False,
+                 logger=True,
+                 batch_size=len(indices),
+             )
+             return loss
+
+         optimizerW.step(closure)
+         if self.non_negativity:
+             nonnegative_projection(*Wk_batch)
+         if self.normalization:
+             normalization_constraint(*Wk_batch)
+
+     def test_step(self, batch, batch_idx):
+         """test step"""
+         X, _ = zip(*batch)
+         W_hat = self(X)
+         loss = self.model(X, W_hat, self.Ph)
+         self.log("test_loss", loss)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         """
+         Parent override.
+
+         ***This function has not been tested***
+         """
+         X, y = zip(*batch)
+         W_hat = self(
+             X
+         )  # Apply the model on the data (requires optimisation of local W)
+         loss = self.model(X, W_hat, self.Ph)
+         # self.log("val_loss", loss)
+         return loss
+
+     def forecast(self, X):
+         """
+         This function forecasts the next time step using the trained phenotypes.
+         This function can be used only with the parameter :math:`\\omega\\geq 2` (`twl>=2`)
+         (phenotypes spanning at least two time instants).
+
+         This function makes a projection of the data on the phenotypes of the model.
+
+         For computational efficiency, the time dimension of :math:`X` is reduced to its last
+         :math:`\\omega-1` time steps, and then extended on the right with :math:`\\omega` time steps of
+         empty values.
+
+         Parameters
+         ----------
+         X: torch.Tensor
+             tensor of dimension :math:`K * N * T` to decompose
+             according to the phenotypes of the model.
+
+         Returns
+         -------
+         torch.Tensor
+             A tensor of dimension :math:`K * N` that is the forecast of the
+             next time step of :math:`X`.
+         """
+
+         if self.twl < 2:
+             # trained with daily phenotypes
+             # TODO throw an error
+             return None
+
+         K = len(X)  # number of patients
+         if self.N != X[0].shape[0]:  # number of medical events
+             # TODO throw an error
+             return None
+
+         # keep only the last twl-1 time steps of the data and append twl empty time steps
+         # (the region to predict)
+         X = [
+             torch.cat(
+                 (xi[:, -(self.twl - 1) :], torch.zeros((self.N, self.twl))), axis=1
+             )
+             for xi in X
+         ]
+
+         # now, we decompose the tensor ... without considering the last part of the
+         # reconstruction, i.e. the predicted part
+         with torch.inference_mode(False):
+             # PyTorch Lightning enables inference mode, which disables gradient computation
+             # more deeply than no_grad(); calling enable_grad() alone is not sufficient.
+
+             Wk_batch = [
+                 Variable(
+                     torch.rand(self.rank, X[Tp].shape[1] - self.twl + 1),
+                     requires_grad=True,
+                 )
+                 for Tp in range(K)
+             ]
+             optimizerW = optim.Adam(Wk_batch, lr=self.params["predict"]["lr"])
+
+             n_epochs = self.params["predict"]["nepochs"]
+             for _ in range(n_epochs):
+
+                 def closure():
+                     optimizerW.zero_grad()
+                     # evaluate the loss based on the beginning of the reconstruction only
+                     loss = self.model(X, Wk_batch, self.Ph.data, padding=(0, self.twl))
+                     if self.pheno_succession:
+                         loss += self.beta * phenotypeSuccession_constraint(
+                             Wk_batch, self.twl
+                         )
+                     loss.backward()
+                     return loss
+
+                 optimizerW.step(closure)
+                 if self.non_negativity:
+                     nonnegative_projection(*Wk_batch)
+                 if self.normalization:
+                     normalization_constraint(*Wk_batch)
+
+         # make a reconstruction, and select only the next event
+         with torch.no_grad():
+             pred = [
+                 self.model.reconstruct(x, self.Ph.data)[:, self.twl] for x in Wk_batch
+             ]
+         return pred
+
+     def reorderPhenotypes(self, gen_pheno, Wk=None, tw=2):
+         """
+         This function outputs reordered internal phenotypes and pathways.
+
+         Parameters
+         ----------
+         gen_pheno: torch.Tensor
+             generated phenotypes of size :math:`R x N x Tw`, where :math:`R` is the number of
+             phenotypes, :math:`N` is the number of drugs and :math:`Tw` is the length of the
+             temporal window
+         Wk: torch.Tensor
+             pathways to reorder; if None, the internal pathways are used
+         tw: int
+             window size
+
+         Returns
+         -------
+         A pair :math:`(rPh,rW)` with reordered phenotypes (best aligned with gen_pheno) and the
+         corresponding reordering of the pathways
+         """
+         if Wk is None:
+             Wk = self.Wk
+
+         if tw == 1:
+             gen_pheno = torch.unsqueeze(gen_pheno, 2)  # add a time dimension of length 1
+
+         if gen_pheno[0].shape != self.Ph[0].shape:
+             raise ValueError(
+                 f"The generated phenotypes ({gen_pheno[0].shape}) and computed phenotypes \
+                 ({self.Ph[0].shape}) do not have the same shape."
+             )
+
+         dic = np.zeros(
+             (gen_pheno.shape[0], self.Ph.shape[0])
+         )  # construct a cost matrix
+
+         for i in range(gen_pheno.shape[0]):
+             for j in range(self.Ph.shape[0]):
+                 dic[i][j] = torch.norm((gen_pheno[i] - self.Ph[j]), p="fro").item()
+
+         m = Munkres()  # use the Hungarian algorithm to find phenotype correspondences
+         indexes = m.compute(dic)
+
+         # Reorder phenotypes
+         reordered_pheno = self.Ph.clone()
+         for row, column in indexes:
+             reordered_pheno[row] = self.Ph[column]
+
+         # Reorder pathways
+         reordered_pathways = [Wk[i].clone() for i in range(len(Wk))]
+         for i in range(len(Wk)):
+             for row, column in indexes:
+                 reordered_pathways[i][row] = Wk[i][column]
+
+         return reordered_pheno, reordered_pathways
+
+
+ class swottedTrainer(pl.Trainer):
+     def fit(
+         self,
+         model: swottedModule,
+         train_dataloaders,
+         val_dataloaders=None,
+         datamodule=None,
+         ckpt_path=None,
+     ):
+         model.Wk = [
+             Variable(
+                 torch.rand(model.rank, ds[0].shape[1] - model.twl + 1),
+                 requires_grad=True,
+             )
+             for ds in train_dataloaders.dataset
+         ]
+         return super().fit(
+             model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+         )
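The pieces above fit together roughly as follows. This is a minimal, illustrative sketch, not taken from the package: the configuration keys follow those read in `__init__`, `configure_optimizers` and `forward`, while the toy dataset and the identity `collate_fn` are assumptions made so that each batch element is a `(tensor, index)` pair, as `training_step` and `swottedTrainer.fit` expect:

    import torch
    from omegaconf import OmegaConf
    from torch.utils.data import DataLoader
    from swotted.swotted import swottedModule, swottedTrainer

    config = OmegaConf.create({
        "model": {"metric": "Bernoulli", "sparsity": 0.1, "non_succession": 0.1,
                  "rank": 4, "N": 10, "twl": 3},
        "training": {"lr": 1e-2},
        "predict": {"lr": 1e-2, "nepochs": 20},
    })

    # toy dataset: K patients, each an N x Tp binary matrix paired with its index
    K, N, Tp = 20, 10, 15
    data = [(torch.randint(0, 2, (N, Tp)).float(), k) for k in range(K)]
    loader = DataLoader(data, batch_size=K, collate_fn=lambda batch: batch)

    model = swottedModule(config)
    trainer = swottedTrainer(max_epochs=100, enable_checkpointing=False)
    trainer.fit(model, train_dataloaders=loader)

    W = model([x for x, _ in data])                   # per-patient assignment tensors
    next_step = model.forecast([x for x, _ in data])  # one N-dim prediction per patient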
swotted/temporal_regularization.py ADDED
@@ -0,0 +1,56 @@
+ # -*- coding: utf-8 -*-
+ """Temporal regularization module
+ """
+ import torch
+ from torch import nn
+
+
+ class TemporalDependency(nn.Module):
+     """Torch module implementing the temporal regularization loss.
+     This module is based on an LSTM.
+     """
+
+     def __init__(self, rank, nlayers, nhidden, dropout):
+         super(TemporalDependency, self).__init__()
+
+         self.nlayers = nlayers
+         self.nhid = nhidden
+
+         self.rnn = nn.LSTM(
+             input_size=rank,
+             hidden_size=nhidden,
+             num_layers=nlayers,
+             dropout=dropout,
+             batch_first=True,
+         )
+         self.decoder = nn.Sequential(nn.Linear(nhidden, rank), nn.ReLU())
+         self.init_weights()
+
+     def init_weights(self):
+         init_range = 0.1
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 m.weight.data.uniform_(-init_range, init_range)
+                 m.bias.data.zero_()
+
+     def forward(self, Ws, device):
+         train_loss = 0.0
+         for Wp in Ws:
+             inputs, targets = Wp[:-1, :], Wp[1:, :]  # seq_len x n_dim
+             seq_len, n_dims = inputs.size()
+
+             hidden = self.init_hidden(1)
+             # seq_len x n_dims --> 1 x seq_len x n_dims
+             outputs, _ = self.rnn(inputs.unsqueeze(0), hidden)
+             logits = self.decoder(outputs.contiguous().view(-1, self.nhid))
+             loss = self.loss(logits, targets)
+             train_loss += loss
+         return train_loss
+
+     def init_hidden(self, batch_sz):
+         size = (self.nlayers, batch_sz, self.nhid)
+         weight = next(self.parameters())
+         return (weight.new_zeros(*size), weight.new_zeros(*size))
+
+     def loss(self, input, target):
+         return torch.mean((input - target) ** 2)
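A minimal sketch of how this regularizer might be applied to a set of pathways. The shapes follow the `forward` implementation (each pathway is a `T x R` matrix with time on the first axis); the hyper-parameter values are illustrative only:

    import torch
    from swotted.temporal_regularization import TemporalDependency

    R = 4  # decomposition rank
    reg = TemporalDependency(rank=R, nlayers=1, nhidden=16, dropout=0.0)

    # three pathways of different lengths, each of shape T x R
    Ws = [torch.rand((12, R)), torch.rand((9, R)), torch.rand((20, R))]

    # LSTM one-step-ahead prediction error, summed over pathways
    penalty = reg(Ws, device="cpu")
    print(penalty.item())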