SURE-tools 2.4.22__py3-none-any.whl → 2.4.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/DensityFlow.py +151 -69
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1414 -0
- SURE/PerturbationAwareDecoder.py +162 -148
- SURE/__init__.py +3 -1
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/METADATA +1 -1
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/RECORD +11 -9
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/top_level.txt +0 -0
SURE/DensityFlow2.py
ADDED
|
@@ -0,0 +1,1422 @@
|
|
|
1
|
+
import pyro
|
|
2
|
+
import pyro.distributions as dist
|
|
3
|
+
from pyro.optim import ExponentialLR
|
|
4
|
+
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_probs
|
|
10
|
+
from torch.distributions import constraints
|
|
11
|
+
from torch.distributions.transforms import SoftmaxTransform
|
|
12
|
+
|
|
13
|
+
from .utils.custom_mlp import MLP, Exp, ZeroBiasMLP
|
|
14
|
+
from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
import argparse
|
|
19
|
+
import random
|
|
20
|
+
import numpy as np
|
|
21
|
+
import datatable as dt
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
from scipy import sparse
|
|
24
|
+
|
|
25
|
+
import scanpy as sc
|
|
26
|
+
from .atac import binarize
|
|
27
|
+
|
|
28
|
+
from typing import Literal
|
|
29
|
+
|
|
30
|
+
import warnings
|
|
31
|
+
warnings.filterwarnings("ignore")
|
|
32
|
+
|
|
33
|
+
import dill as pickle
|
|
34
|
+
import gzip
|
|
35
|
+
from packaging.version import Version
|
|
36
|
+
torch_version = torch.__version__
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def set_random_seed(seed):
    """Seed every random number generator the model code draws from.

    Seeds PyTorch (CPU, plus the current and all other CUDA devices when CUDA
    is available), NumPy, Python's ``random`` module and Pyro.

    Args:
        seed: Integer seed passed unchanged to each generator.
    """
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        # Current device first, then every visible GPU for multi-GPU setups.
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    for seeder in (np.random.seed, random.seed, pyro.set_rng_seed):
        seeder(seed)
|
|
56
|
+
|
|
57
|
+
class DensityFlow2(nn.Module):
|
|
58
|
+
def __init__(self,
             input_size: int,
             codebook_size: int = 200,
             cell_factor_size: int = 0,
             turn_off_cell_specific: bool = False,
             supervised_mode: bool = False,
             z_dim: int = 10,
             z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
             loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
             inverse_dispersion: float = 10.0,
             use_zeroinflate: bool = False,
             hidden_layers: list = [500],
             hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
             nn_dropout: float = 0.1,
             post_layer_fct: list = ['layernorm'],
             post_act_fct: list = None,
             config_enum: str = 'parallel',
             use_cuda: bool = True,
             seed: int = 42,
             zero_bias: bool|list = True,
             dtype = torch.float32, # type: ignore
             ):
    """Configure a DensityFlow2 model, seed the RNGs and build its networks.

    Args:
        input_size: Number of observed features (width of ``xs``).
        codebook_size: Number of codebook entries (size of the one-hot code ``n``).
        cell_factor_size: Number of cell-factor shift networks (columns of ``us``).
        turn_off_cell_specific: If True, the shift networks take only the factor
            value (input width 1) instead of ``[latent, factor]``.
        supervised_mode: If True, ``encoder_n`` reads ``xs`` instead of the latent.
        z_dim: Latent dimension.
        z_dist: Prior family of the latent ``zn``.
        loss_func: Observation likelihood.
        inverse_dispersion: Initial value / prior location of the negative-binomial
            inverse dispersion.
        use_zeroinflate: Wrap the likelihood in a zero-inflated distribution.
        hidden_layers: Hidden widths of the encoders; decoders use them reversed.
        hidden_layer_activation: Activation used between hidden layers.
        nn_dropout: Dropout probability used by the post-layer hooks.
        post_layer_fct: Spec (containing 'layernorm'/'layer_norm',
            'batchnorm'/'batch_norm' and/or 'dropout') applied after each layer,
            or None.
        post_act_fct: Same kind of spec applied after each activation, or None.
        config_enum: 'parallel' enables broadcasting inside the MLPs.
        use_cuda: Move the module to CUDA after building it.
        seed: Seed passed to ``set_random_seed`` before the networks are built.
        zero_bias: Bool for all shift networks, or one bool per factor; True
            selects ``ZeroBiasMLP``.
        dtype: Stored as ``self.dtype``.

    Raises:
        ValueError: If ``z_dist`` or ``loss_func`` is not a supported option.
    """
    super().__init__()

    # The model methods dispatch on these strings with if/elif chains; an
    # unknown value would otherwise fail late or silently skip the likelihood.
    supported_z_dists = ('normal', 'studentt', 'laplacian', 'cauchy', 'gumbel')
    if z_dist not in supported_z_dists:
        raise ValueError(f"unsupported z_dist {z_dist!r}; expected one of {supported_z_dists}")
    supported_losses = ('negbinomial', 'poisson', 'multinomial', 'bernoulli')
    if loss_func not in supported_losses:
        raise ValueError(f"unsupported loss_func {loss_func!r}; expected one of {supported_losses}")

    self.input_size = input_size
    self.cell_factor_size = cell_factor_size
    self.inverse_dispersion = inverse_dispersion
    self.latent_dim = z_dim
    # Copy list arguments so no instance keeps a reference to the shared
    # default list objects in the signature (or to a caller-owned list).
    self.hidden_layers = list(hidden_layers)
    self.decoder_hidden_layers = self.hidden_layers[::-1]
    self.allow_broadcast = config_enum == 'parallel'
    self.use_cuda = use_cuda
    self.loss_func = loss_func
    self.options = None
    self.code_size = codebook_size
    self.supervised_mode = supervised_mode
    self.latent_dist = z_dist
    # Alias of latent_dist: some model methods (e.g. model3) read self.z_dist.
    self.z_dist = z_dist
    self.dtype = dtype
    self.use_zeroinflate = use_zeroinflate
    self.nn_dropout = nn_dropout
    self.post_layer_fct = list(post_layer_fct) if isinstance(post_layer_fct, list) else post_layer_fct
    self.post_act_fct = list(post_act_fct) if isinstance(post_act_fct, list) else post_act_fct
    self.hidden_layer_activation = hidden_layer_activation
    if isinstance(zero_bias, list):
        self.use_bias = [not flag for flag in zero_bias]
    else:
        self.use_bias = [not zero_bias] * self.cell_factor_size
    self.turn_off_cell_specific = turn_off_cell_specific

    self.codebook_weights = None

    # Seed before building so the initial weights are reproducible.
    set_random_seed(seed)
    self.setup_networks()

    print(f"🧬 DensityFlow2 Initialized:")
    print(f" - Latent Dimension: {self.latent_dim}")
    print(f" - Gene Dimension: {self.input_size}")
    print(f" - Hidden Dimensions: {self.hidden_layers}")
    print(f" - Device: {self.get_device()}")
    print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
|
|
119
|
+
|
|
120
|
+
def setup_networks(self):
    """Create every sub-network used by the model and guide methods.

    Builds ``encoder_n``, ``encoder_zn``, ``encoder_inverse_dispersion``
    (negbinomial only), ``cell_factor_effect`` (only when
    ``cell_factor_size > 0``), ``decoder_concentrate`` and ``codebook``, then
    moves the module to CUDA when ``use_cuda`` is set.

    Raises:
        ValueError: If ``hidden_layer_activation`` is not supported, or if the
            per-factor ``zero_bias`` list has fewer than ``cell_factor_size``
            entries.
    """
    latent_dim = self.latent_dim
    hidden_sizes = self.hidden_layers

    def parse_spec(spec):
        """Return the (layer_norm, batch_norm, dropout) flags requested by ``spec``."""
        if spec is None:
            return False, False, False
        layer_norm = ('layernorm' in spec) or ('layer_norm' in spec)
        batch_norm = ('batchnorm' in spec) or ('batch_norm' in spec)
        dropout = 'dropout' in spec
        return layer_norm, batch_norm, dropout

    def make_post_fct(spec):
        """Build the MLP hook emitting Dropout -> BatchNorm1d -> LayerNorm as requested."""
        layer_norm, batch_norm, dropout = parse_spec(spec)

        def post_fct(layer_ix, total_layers, layer):
            modules = []
            if dropout:
                modules.append(nn.Dropout(self.nn_dropout))
            # out_features is only read when a norm layer actually needs it.
            if batch_norm:
                modules.append(nn.BatchNorm1d(layer.module.out_features))
            if layer_norm:
                modules.append(nn.LayerNorm(layer.module.out_features))
            if not modules:
                return None
            # A single module is returned bare, not wrapped in nn.Sequential;
            # wrapping would change the parameter names in the state_dict.
            return modules[0] if len(modules) == 1 else nn.Sequential(*modules)

        return post_fct

    post_layer_fct = make_post_fct(self.post_layer_fct)
    post_act_fct = make_post_fct(self.post_act_fct)

    activations = {
        'relu': nn.ReLU,
        'softplus': nn.Softplus,
        'leakyrelu': nn.LeakyReLU,
        'linear': nn.Identity,
    }
    if self.hidden_layer_activation not in activations:
        raise ValueError(
            f"unsupported hidden_layer_activation {self.hidden_layer_activation!r}; "
            f"expected one of {sorted(activations)}")
    activate_fct = activations[self.hidden_layer_activation]

    def build(sizes, output_activation=None, mlp_cls=MLP):
        """Construct an MLP sharing the activation, hooks and device settings."""
        return mlp_cls(
            sizes,
            activation=activate_fct,
            output_activation=output_activation,
            post_layer_fct=post_layer_fct,
            post_act_fct=post_act_fct,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )

    # Keep this creation order: the modules draw their initial weights from
    # the RNG seeded in __init__, so reordering changes the initialisation.
    encoder_n_input = self.input_size if self.supervised_mode else latent_dim
    self.encoder_n = build([encoder_n_input] + hidden_sizes + [self.code_size])

    self.encoder_zn = build(
        [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
        output_activation=[None, Exp],
    )

    if self.loss_func == 'negbinomial':
        self.encoder_inverse_dispersion = build(
            [latent_dim] + hidden_sizes + [[self.input_size, self.input_size]],
            output_activation=[Exp, Exp],
        )

    if self.cell_factor_size > 0:
        if len(self.use_bias) < self.cell_factor_size:
            raise ValueError(
                f"zero_bias provides {len(self.use_bias)} entries but "
                f"cell_factor_size is {self.cell_factor_size}")
        # Cell-specific shifts see [latent, factor]; otherwise only the factor.
        shift_input = 1 if self.turn_off_cell_specific else latent_dim + 1
        self.cell_factor_effect = nn.ModuleList()
        for i in range(self.cell_factor_size):
            mlp_cls = MLP if self.use_bias[i] else ZeroBiasMLP
            self.cell_factor_effect.append(
                build([shift_input] + self.decoder_hidden_layers + [latent_dim], mlp_cls=mlp_cls)
            )

    self.decoder_concentrate = build([latent_dim] + self.decoder_hidden_layers + [self.input_size])

    if self.latent_dist == 'studentt':
        # Student-t codebook emits (degrees of freedom, location).
        self.codebook = build(
            [self.code_size] + hidden_sizes + [[latent_dim, latent_dim]],
            output_activation=[Exp, None],
        )
    else:
        self.codebook = build([self.code_size] + hidden_sizes + [latent_dim])

    if self.use_cuda:
        self.cuda()
|
|
305
|
+
|
|
306
|
+
def get_device(self):
    """Return the device of the first registered parameter of this module.

    Raises ``StopIteration`` if the module has no parameters.
    """
    first_param = next(iter(self.parameters()))
    return first_param.device
|
|
308
|
+
|
|
309
|
+
def cutoff(self, xs, thresh=None):
    """Clamp ``xs`` from below and replace NaNs with the same floor.

    The floor is the machine epsilon of ``xs.dtype``, raised to ``thresh``
    when a larger threshold is supplied.

    Args:
        xs: Floating-point tensor; it is not modified.
        thresh: Optional lower bound, used only if it exceeds machine epsilon.

    Returns:
        A new tensor whose entries are all ``>= floor``, with NaNs set to ``floor``.
    """
    floor = torch.finfo(xs.dtype).eps
    if thresh is not None and floor < thresh:
        floor = thresh

    clamped = xs.clamp(min=floor)
    # clamp() propagates NaN, so those entries are patched explicitly.
    nan_mask = torch.isnan(clamped)
    if torch.any(nan_mask):
        clamped[nan_mask] = floor
    return clamped
|
|
322
|
+
|
|
323
|
+
def softmax(self, xs):
    """Turn logits ``xs`` into probabilities along the last dimension.

    Returns the mean of a one-trial multinomial parameterised by ``xs`` as
    logits (i.e. the normalised category probabilities).
    """
    one_trial = dist.Multinomial(total_count=1, logits=xs)
    return one_trial.mean
|
|
327
|
+
|
|
328
|
+
def sigmoid(self, xs):
    """Map logits ``xs`` to probabilities element-wise.

    Returns the mean of a Bernoulli parameterised by ``xs`` as logits.
    """
    bernoulli = dist.Bernoulli(logits=xs)
    return bernoulli.mean
|
|
334
|
+
|
|
335
|
+
def softmax_logit(self, xs):
    """Return the element-wise logit of ``self.softmax(xs)``.

    ``torch.logit`` clips the probabilities to ``[eps, 1 - eps]``, where eps
    is the machine epsilon of ``xs.dtype``, so the result stays finite.
    """
    clip = torch.finfo(xs.dtype).eps
    probs = self.softmax(xs)
    return torch.logit(probs, eps=clip)
|
|
340
|
+
|
|
341
|
+
def logit(self, xs):
    """Element-wise logit of probabilities ``xs``, clipped to ``[eps, 1 - eps]``.

    eps is the machine epsilon of ``xs.dtype``; clipping keeps 0 and 1 finite.
    """
    return torch.logit(xs, eps=torch.finfo(xs.dtype).eps)
|
|
345
|
+
|
|
346
|
+
def dirimulti_param(self, xs):
    """Return ``self.dirimulti_mass`` times ``self.sigmoid(xs)``.

    NOTE(review): ``dirimulti_mass`` is not assigned in ``__init__`` or
    ``setup_networks``; confirm it is set elsewhere before this is called.
    """
    mass = self.dirimulti_mass
    return mass * self.sigmoid(xs)
|
|
349
|
+
|
|
350
|
+
def multi_param(self, xs):
    """Return multinomial probabilities for logits ``xs`` via ``self.softmax``."""
    return self.softmax(xs)
|
|
353
|
+
|
|
354
|
+
def model1(self, xs):
    """Unsupervised generative model: uniform code -> latent -> observations.

    Each row of ``xs`` gets a one-hot code ``n`` from a uniform categorical
    prior, a latent ``zn`` centred on that code's codebook embedding (family
    chosen by ``latent_dist``), and finally the likelihood of ``x`` through
    ``decoder_concentrate``.

    Args:
        xs: Observation matrix of shape (batch, input_size).

    Raises:
        ValueError: If ``latent_dist`` or ``loss_func`` is unsupported.
    """
    pyro.module('DensityFlow2', self)

    batch_size = xs.size(0)
    self.options = dict(dtype=xs.dtype, device=xs.device)

    if self.loss_func == 'negbinomial':
        total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                 xs.new_ones(self.input_size), constraint=constraints.positive)

    if self.use_zeroinflate:
        gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

    acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

    # Create the code identity on the data's device/dtype; a default CPU
    # float32 tensor fails once the codebook lives on CUDA or uses float64.
    codes = torch.eye(self.code_size, **self.options)
    if self.latent_dist == 'studentt':
        acs_dof, acs_loc = self.codebook(codes)
    else:
        acs_loc = self.codebook(codes)

    with pyro.plate('data'):
        prior = torch.zeros(batch_size, self.code_size, **self.options)
        ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))

        zn_loc = torch.matmul(ns, acs_loc)
        zn_scale = acs_scale

        if self.latent_dist == 'studentt':
            zn_prior = dist.StudentT(df=torch.matmul(ns, acs_dof), loc=zn_loc, scale=zn_scale)
        elif self.latent_dist == 'laplacian':
            zn_prior = dist.Laplace(zn_loc, zn_scale)
        elif self.latent_dist == 'cauchy':
            zn_prior = dist.Cauchy(zn_loc, zn_scale)
        elif self.latent_dist == 'normal':
            zn_prior = dist.Normal(zn_loc, zn_scale)
        elif self.latent_dist == 'gumbel':
            zn_prior = dist.Gumbel(zn_loc, zn_scale)
        else:
            raise ValueError(f"unsupported latent distribution {self.latent_dist!r}")
        zns = pyro.sample('zn', zn_prior.to_event(1))

        concentrate = self.decoder_concentrate(zns)

        if self.loss_func == 'multinomial':
            # Multinomial already has an event dim and is never zero-inflated.
            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
            return

        obs = xs
        if self.loss_func == 'bernoulli':
            x_dist = dist.Bernoulli(logits=concentrate)
        else:
            # Row-normalised exp(concentrate): mean of a one-trial Dirichlet-multinomial.
            theta = dist.DirichletMultinomial(total_count=1, concentration=concentrate.exp()).mean
            if self.loss_func == 'negbinomial':
                x_dist = dist.NegativeBinomial(total_count=total_count, probs=theta)
            elif self.loss_func == 'poisson':
                x_dist = dist.Poisson(rate=theta * torch.sum(xs, dim=1, keepdim=True))
                obs = xs.round()
            else:
                raise ValueError(f"unsupported loss_func {self.loss_func!r}")

        if self.use_zeroinflate:
            x_dist = dist.ZeroInflatedDistribution(x_dist, gate_logits=gate_logits)
        pyro.sample('x', x_dist.to_event(1), obs=obs)
|
|
424
|
+
|
|
425
|
+
def guide1(self, xs):
    """Variational guide paired with ``model1``.

    Draws ``zn`` from a diagonal Normal whose parameters come from
    ``self._get_basal_embedding(xs)``, then the code ``n`` from a one-hot
    categorical with logits ``self.encoder_n(zn)``.
    """
    with pyro.plate('data'):
        loc, scale = self._get_basal_embedding(xs)
        latent = pyro.sample('zn', dist.Normal(loc, scale).to_event(1))
        pyro.sample('n', dist.OneHotCategorical(logits=self.encoder_n(latent)))
|
|
433
|
+
|
|
434
|
+
def model2(self, xs, us=None):
    """Generative model with additive cell-factor shifts in decoder space.

    The code/latent prior is the same as in ``model1``. In addition, for each
    of the ``cell_factor_size`` columns of ``us``, the shifted latent
    ``self._cell_shift(zn, i, us[:, i])`` is decoded and added to the decoder
    output of ``zn`` before the likelihood.

    Args:
        xs: Observation matrix of shape (batch, input_size).
        us: Cell-factor matrix; column ``i`` drives shift ``i``. Required when
            ``cell_factor_size > 0``.

    Raises:
        ValueError: If ``us`` is missing while ``cell_factor_size > 0``, or if
            ``latent_dist`` / ``loss_func`` is unsupported.
    """
    pyro.module('DensityFlow2', self)

    if self.cell_factor_size > 0 and us is None:
        raise ValueError("model2 requires `us` when cell_factor_size > 0")

    batch_size = xs.size(0)
    self.options = dict(dtype=xs.dtype, device=xs.device)

    if self.loss_func == 'negbinomial':
        # One inverse dispersion per gene: inside the "genes" plate the scalar
        # LogNormal is expanded to shape (input_size,). Calling .to_event(1) on
        # a distribution with empty batch shape raises ValueError, so it is not
        # used. The tensors are built like xs so the draw lands on xs's device.
        # NOTE(review): LogNormal's first argument is the mean of log(value),
        # so this prior is centred at exp(self.inverse_dispersion); confirm
        # whether log(self.inverse_dispersion) was intended.
        # NOTE(review): guide2 samples this site from per-cell
        # (batch, input_size) parameters, which does not match this per-gene site.
        with pyro.plate("genes", self.input_size):
            inverse_dispersion = pyro.sample(
                "inverse_dispersion",
                dist.LogNormal(xs.new_tensor(self.inverse_dispersion), xs.new_tensor(0.5)),
            )

    if self.use_zeroinflate:
        gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

    acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

    # Create the code identity on the data's device/dtype (see model1).
    codes = torch.eye(self.code_size, **self.options)
    if self.latent_dist == 'studentt':
        acs_dof, acs_loc = self.codebook(codes)
    else:
        acs_loc = self.codebook(codes)

    with pyro.plate('data'):
        prior = torch.zeros(batch_size, self.code_size, **self.options)
        ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))

        zn_loc = torch.matmul(ns, acs_loc)
        zn_scale = acs_scale

        if self.latent_dist == 'studentt':
            zn_prior = dist.StudentT(df=torch.matmul(ns, acs_dof), loc=zn_loc, scale=zn_scale)
        elif self.latent_dist == 'laplacian':
            zn_prior = dist.Laplace(zn_loc, zn_scale)
        elif self.latent_dist == 'cauchy':
            zn_prior = dist.Cauchy(zn_loc, zn_scale)
        elif self.latent_dist == 'normal':
            zn_prior = dist.Normal(zn_loc, zn_scale)
        elif self.latent_dist == 'gumbel':
            zn_prior = dist.Gumbel(zn_loc, zn_scale)
        else:
            raise ValueError(f"unsupported latent distribution {self.latent_dist!r}")
        zns = pyro.sample('zn', zn_prior.to_event(1))

        concentrate = self.decoder_concentrate(zns)
        for i in range(self.cell_factor_size):
            zus = self._cell_shift(zns, i, us[:, i].reshape(-1, 1))
            # Out-of-place add: never mutate a tensor autograd may still need.
            concentrate = concentrate + self.decoder_concentrate(zus)

        if self.loss_func == 'multinomial':
            # Multinomial already has an event dim and is never zero-inflated.
            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
            return

        obs = xs
        if self.loss_func == 'bernoulli':
            x_dist = dist.Bernoulli(logits=concentrate)
        elif self.loss_func == 'negbinomial':
            # log(exp(concentrate)) is concentrate itself; using it directly
            # avoids exp overflow and the NaN gradients that follow from it.
            logits = (concentrate - inverse_dispersion.log()).clamp(min=-10, max=10)
            x_dist = dist.NegativeBinomial(total_count=inverse_dispersion, logits=logits)
        elif self.loss_func == 'poisson':
            theta = dist.DirichletMultinomial(total_count=1, concentration=concentrate.exp()).mean
            x_dist = dist.Poisson(rate=theta * torch.sum(xs, dim=1, keepdim=True))
            obs = xs.round()
        else:
            raise ValueError(f"unsupported loss_func {self.loss_func!r}")

        if self.use_zeroinflate:
            x_dist = dist.ZeroInflatedDistribution(x_dist, gate_logits=gate_logits)
        pyro.sample('x', x_dist.to_event(1), obs=obs)
|
|
521
|
+
|
|
522
|
+
def guide2(self, xs, us=None):
    """Variational guide paired with ``model2``.

    Draws ``zn`` from a diagonal Normal parameterised by
    ``self._get_basal_embedding(xs)`` and the code ``n`` from a one-hot
    categorical with logits ``self.encoder_n(zn)``. For the negbinomial loss
    it also samples ``inverse_dispersion`` from a LogNormal whose parameters
    come from ``self.encoder_inverse_dispersion(zn)``. ``us`` is not read.
    """
    with pyro.plate('data'):
        loc, scale = self._get_basal_embedding(xs)
        latent = pyro.sample('zn', dist.Normal(loc, scale).to_event(1))
        pyro.sample('n', dist.OneHotCategorical(logits=self.encoder_n(latent)))

    if self.loss_func != 'negbinomial':
        return

    # NOTE(review): these parameters are per row of xs, shape
    # (batch, input_size), while the "genes" plate has size input_size.
    # Confirm the intended shape of this site, and whether it belongs inside
    # the 'data' plate.
    id_loc, id_scale = self.encoder_inverse_dispersion(latent)
    with pyro.plate("genes", self.input_size):
        pyro.sample("inverse_dispersion", dist.LogNormal(id_loc, id_scale).to_event(1))
|
|
535
|
+
|
|
536
|
+
def model3(self, xs, ys, embeds=None):
    """Supervised generative model with observed codes.

    The code ``n`` is observed as ``ys`` under a one-hot categorical whose
    logits are ``self.encoder_n(xs)``. The latent ``zn`` is centred on that
    code's codebook embedding and is observed as ``embeds`` when given.
    Finally ``xs`` is scored through ``decoder_concentrate``.

    Args:
        xs: Observation matrix of shape (batch, input_size).
        ys: Observed one-hot codes.
        embeds: Optional observed latents; when None, ``zn`` is latent.

    Raises:
        ValueError: If ``latent_dist`` or ``loss_func`` is unsupported.
    """
    pyro.module('DensityFlow2', self)

    self.options = dict(dtype=xs.dtype, device=xs.device)

    if self.loss_func == 'negbinomial':
        total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                 xs.new_ones(self.input_size), constraint=constraints.positive)

    if self.use_zeroinflate:
        gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

    acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

    # Create the code identity on the data's device/dtype; a default CPU
    # float32 tensor fails once the codebook lives on CUDA or uses float64.
    codes = torch.eye(self.code_size, **self.options)
    if self.latent_dist == 'studentt':
        acs_dof, acs_loc = self.codebook(codes)
    else:
        acs_loc = self.codebook(codes)

    with pyro.plate('data'):
        prior = self.encoder_n(xs)
        ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)

        zn_loc = torch.matmul(ns, acs_loc)
        zn_scale = acs_scale

        # Dispatch on latent_dist (the attribute __init__ always sets).
        if self.latent_dist == 'studentt':
            zn_prior = dist.StudentT(df=torch.matmul(ns, acs_dof), loc=zn_loc, scale=zn_scale)
        elif self.latent_dist == 'laplacian':
            zn_prior = dist.Laplace(zn_loc, zn_scale)
        elif self.latent_dist == 'cauchy':
            zn_prior = dist.Cauchy(zn_loc, zn_scale)
        elif self.latent_dist == 'normal':
            zn_prior = dist.Normal(zn_loc, zn_scale)
        elif self.latent_dist == 'gumbel':
            zn_prior = dist.Gumbel(zn_loc, zn_scale)
        else:
            raise ValueError(f"unsupported latent distribution {self.latent_dist!r}")
        # obs=None is exactly an unobserved site in pyro.sample.
        zns = pyro.sample('zn', zn_prior.to_event(1), obs=embeds)

        concentrate = self.decoder_concentrate(zns)

        if self.loss_func == 'multinomial':
            # Multinomial already has an event dim and is never zero-inflated.
            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
            return

        obs = xs
        if self.loss_func == 'bernoulli':
            x_dist = dist.Bernoulli(logits=concentrate)
        else:
            # Row-normalised exp(concentrate): mean of a one-trial Dirichlet-multinomial.
            theta = dist.DirichletMultinomial(total_count=1, concentration=concentrate.exp()).mean
            if self.loss_func == 'negbinomial':
                x_dist = dist.NegativeBinomial(total_count=total_count, probs=theta)
            elif self.loss_func == 'poisson':
                x_dist = dist.Poisson(rate=theta * torch.sum(xs, dim=1, keepdim=True))
                obs = xs.round()
            else:
                raise ValueError(f"unsupported loss_func {self.loss_func!r}")

        if self.use_zeroinflate:
            x_dist = dist.ZeroInflatedDistribution(x_dist, gate_logits=gate_logits)
        pyro.sample('x', x_dist.to_event(1), obs=obs)
|
|
623
|
+
|
|
624
|
+
def guide3(self, xs, ys, embeds=None):
    """Variational guide paired with ``model3``.

    ``ys`` is not read, because model3 observes the code ``n``. When
    ``embeds`` is given, model3 also observes ``zn`` and nothing is sampled
    here. Otherwise ``zn`` is drawn from a diagonal Normal parameterised by
    ``self._get_basal_embedding(xs)``.
    """
    with pyro.plate('data'):
        if embeds is not None:
            return
        loc, scale = self._get_basal_embedding(xs)
        pyro.sample('zn', dist.Normal(loc, scale).to_event(1))
|
|
632
|
+
|
|
633
|
+
def model4(self, xs, us, ys, embeds=None):
    """Generative model used by ``fit`` when both cell-level factors ``us`` and labels ``ys`` are given.

    A one-hot metacell assignment ``n`` is observed as ``ys`` (prior logits from
    ``encoder_n(xs)``). The basal latent ``zn`` is drawn around the assigned codebook entry
    (observed when ``embeds`` is given). The expression ``x`` is scored under
    ``self.loss_func`` with logits
    ``decoder_concentrate(zn) + sum_i decoder_concentrate(_cell_shift(zn, i, us[:, i]))``.

    Raises
    ------
    ValueError
        If ``self.latent_dist`` is not one of the supported latent distributions.
    """
    pyro.module('DensityFlow2', self)

    self.options = dict(dtype=xs.dtype, device=xs.device)

    if self.loss_func == 'negbinomial':
        total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                 xs.new_ones(self.input_size), constraint=constraints.positive)

    if self.use_zeroinflate:
        gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

    acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

    # Build the identity on the data's device/dtype; torch.eye otherwise defaults to CPU
    # and the default dtype, which breaks the codebook network on GPU.
    I = torch.eye(self.code_size, **self.options)
    if self.latent_dist == 'studentt':
        acs_dof, acs_loc = self.codebook(I)
    else:
        acs_loc = self.codebook(I)

    with pyro.plate('data'):
        prior = self.encoder_n(xs)
        ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)

        zn_loc = torch.matmul(ns, acs_loc)
        zn_scale = acs_scale

        if self.latent_dist == 'studentt':
            prior_dof = torch.matmul(ns, acs_dof)
            zn_dist = dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale)
        elif self.latent_dist == 'laplacian':
            zn_dist = dist.Laplace(zn_loc, zn_scale)
        elif self.latent_dist == 'cauchy':
            zn_dist = dist.Cauchy(zn_loc, zn_scale)
        elif self.latent_dist == 'normal':
            zn_dist = dist.Normal(zn_loc, zn_scale)
        elif self.latent_dist == 'gumbel':
            # Checked against latent_dist like every other branch (was self.z_dist).
            zn_dist = dist.Gumbel(zn_loc, zn_scale)
        else:
            raise ValueError(f"unsupported latent distribution: {self.latent_dist!r}")
        # obs=None samples the site; otherwise the provided embeddings are observed.
        zns = pyro.sample('zn', zn_dist.to_event(1), obs=embeds)

        zs = zns
        concentrate = self.decoder_concentrate(zs)
        for i in range(self.cell_factor_size):
            zus = self._cell_shift(zs, i, us[:, i].reshape(-1, 1))
            # Out-of-place add: the decoder output may be saved for the backward pass.
            concentrate = concentrate + self.decoder_concentrate(zus)

        if self.loss_func in ['bernoulli']:
            log_theta = concentrate
        else:
            rate = concentrate.exp()
            theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
            if self.loss_func == 'poisson':
                rate = theta * torch.sum(xs, dim=1, keepdim=True)

        if self.loss_func == 'negbinomial':
            if self.use_zeroinflate:
                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
            else:
                pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
        elif self.loss_func == 'poisson':
            if self.use_zeroinflate:
                pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
            else:
                pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
        elif self.loss_func == 'multinomial':
            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
        elif self.loss_func == 'bernoulli':
            if self.use_zeroinflate:
                pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
            else:
                pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
|
|
729
|
+
|
|
730
|
+
def guide4(self, xs, us, ys, embeds=None):
    """Variational guide paired with ``model4`` (used by ``fit`` when both ``us`` and ``ys`` are given).

    Unless precomputed ``embeds`` are supplied, the basal latent ``zn`` is drawn from a
    Normal parameterised by ``_get_basal_embedding(xs)``. ``us`` and ``ys`` are not read.
    """
    with pyro.plate('data'):
        if embeds is not None:
            # Embeddings are observed in the model; nothing to sample here.
            return
        basal_loc, basal_scale = self._get_basal_embedding(xs)
        pyro.sample('zn', dist.Normal(basal_loc, basal_scale).to_event(1))
|
|
738
|
+
|
|
739
|
+
def _total_shifts(self, zns, us):
    """Sum the latent shifts contributed by every cell-level factor.

    Factor ``idx`` contributes ``_cell_shift(zns, idx, us[:, idx])`` with the column reshaped
    to ``(-1, 1)``. Returns ``None`` when ``self.cell_factor_size`` is 0.
    """
    total = None
    for idx in range(self.cell_factor_size):
        shift = self._cell_shift(zns, idx, us[:, idx].reshape(-1, 1))
        total = shift if idx == 0 else total + shift
    return total
|
|
755
|
+
|
|
756
|
+
def _get_codebook_identity(self):
    """Return a ``code_size`` x ``code_size`` identity built with ``self.options`` (dtype/device)."""
    size = self.code_size
    return torch.eye(size, **self.options)
|
|
758
|
+
|
|
759
|
+
def _get_codebook(self):
    """Evaluate the codebook network on the identity and return its location part.

    For the 'studentt' latent distribution the network returns a pair and the second
    element is kept. Otherwise the output is returned as-is.
    """
    identity = torch.eye(self.code_size, **self.options)
    output = self.codebook(identity)
    if self.latent_dist == 'studentt':
        _, output = output
    return output
|
|
766
|
+
|
|
767
|
+
def get_codebook(self):
    """
    Return the location (mean) part of the metacell codebook as a NumPy array.
    """
    return tensor_to_numpy(self._get_codebook())
|
|
774
|
+
|
|
775
|
+
def _get_basal_embedding(self, xs):
    """Run ``encoder_zn`` on ``xs`` and return its two outputs as ``(loc, scale)``."""
    zn_loc, zn_scale = self.encoder_zn(xs)
    return zn_loc, zn_scale
|
|
778
|
+
|
|
779
|
+
def get_basal_embedding(self,
                        xs,
                        batch_size: int = 1024):
    """
    Return cells' basal latent representations.

    The input is preprocessed with ``self.preprocess`` and encoded batch by batch. The
    location output of ``encoder_zn`` (via ``_get_basal_embedding``) is kept.

    Parameters
    ----------
    xs
        Single-cell expression matrix. It should be a Numpy array or a Pytorch Tensor.
    batch_size
        Size of batch processing.

    Returns
    -------
    numpy.ndarray
        One row per input cell, in input order.
    """
    xs = self.preprocess(xs)
    xs = convert_to_tensor(xs, device=self.get_device())
    dataset = CustomDataset(xs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    Z = []
    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad(), tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for X_batch, _ in dataloader:
            zns, _ = self._get_basal_embedding(X_batch)
            Z.append(tensor_to_numpy(zns))
            pbar.update(1)

    Z = np.concatenate(Z)
    return Z
|
|
810
|
+
|
|
811
|
+
def _code(self, xs):
    """Return metacell assignment logits from ``encoder_n``.

    In supervised mode ``encoder_n`` is applied to ``xs`` directly. Otherwise it is applied
    to the basal embedding location.
    """
    if not self.supervised_mode:
        basal, _ = self._get_basal_embedding(xs)
        return self.encoder_n(basal)
    return self.encoder_n(xs)
|
|
819
|
+
|
|
820
|
+
def code(self, xs, batch_size=1024):
    """
    Return the metacell assignment logits (output of ``_code``) for every cell.

    The input is preprocessed and processed in batches of ``batch_size``. Results are
    concatenated into one NumPy array in input order.
    """
    prepared = convert_to_tensor(self.preprocess(xs), device=self.get_device())
    loader = DataLoader(CustomDataset(prepared), batch_size=batch_size, shuffle=False)

    logits = []
    with tqdm(total=len(loader), desc='', unit='batch') as progress:
        for batch, _ in loader:
            logits.append(tensor_to_numpy(self._code(batch)))
            progress.update(1)

    return np.concatenate(logits)
|
|
835
|
+
|
|
836
|
+
def _soft_assignments(self, xs):
    """Return ``self.softmax`` applied to the assignment logits from ``_code``."""
    return self.softmax(self._code(xs))
|
|
840
|
+
|
|
841
|
+
def soft_assignments(self, xs, batch_size=1024):
    """
    Map cells to metacells and return probabilistic assignment values, one row per cell.
    """
    prepared = convert_to_tensor(self.preprocess(xs), device=self.get_device())
    loader = DataLoader(CustomDataset(prepared), batch_size=batch_size, shuffle=False)

    probabilities = []
    with tqdm(total=len(loader), desc='', unit='batch') as progress:
        for batch, _ in loader:
            probabilities.append(tensor_to_numpy(self._soft_assignments(batch)))
            progress.update(1)

    return np.concatenate(probabilities)
|
|
859
|
+
|
|
860
|
+
def _hard_assignments(self, xs):
    """Return one-hot rows marking the top-scoring metacell (via ``torch.topk``) per cell."""
    scores = self._code(xs)
    _, top_index = torch.topk(scores, 1)
    one_hot = torch.zeros_like(scores)
    return one_hot.scatter_(1, top_index, 1.0)
|
|
865
|
+
|
|
866
|
+
def hard_assignments(self, xs, batch_size=1024):
    """
    Map cells to metacells and return the assigned identities as one-hot rows.
    """
    prepared = convert_to_tensor(self.preprocess(xs), device=self.get_device())
    loader = DataLoader(CustomDataset(prepared), batch_size=batch_size, shuffle=False)

    assignments = []
    with tqdm(total=len(loader), desc='', unit='batch') as progress:
        for batch, _ in loader:
            assignments.append(tensor_to_numpy(self._hard_assignments(batch)))
            progress.update(1)

    return np.concatenate(assignments)
|
|
884
|
+
|
|
885
|
+
def predict(self, xs, us, perturbs_predict:list, perturbs_reference:list, library_sizes=None):
    """Predict expected counts after applying the listed perturbations.

    The decoder of this model is additive in logit space: ``model4`` decodes the basal
    latent and each factor shift separately and sums the logits. Prediction follows the
    same rule: ``concentrate = get_theta(basal) + sum(get_theta(shift))``. That sum is then
    converted to counts with ``get_counts``.

    Parameters
    ----------
    xs
        Expression matrix (rows are cells).
    us
        Cell-level factor matrix; column ``j`` corresponds to ``perturbs_reference[j]``.
    perturbs_predict
        Names of the factors to perturb.
    perturbs_reference
        Ordered names of all factor columns in ``us``.
    library_sizes
        Per-cell library sizes. Defaults to the row sums of ``xs``.

    Returns
    -------
    counts, zs
        Predicted counts, and the latent representation with all shifts added.

    Raises
    ------
    ValueError
        If a name in ``perturbs_predict`` does not occur exactly once in ``perturbs_reference``.
    """
    perturbs_reference = np.array(perturbs_reference)

    # basal embedding and its decoded contribution
    zs = self.get_basal_embedding(xs)
    concentrate = self.get_theta(zs)
    for pert in perturbs_predict:
        matches = np.flatnonzero(perturbs_reference == pert)
        if matches.size != 1:
            raise ValueError(
                f"perturbation {pert!r} must occur exactly once in perturbs_reference")
        pert_idx = int(matches[0])
        us_i = us[:, pert_idx].reshape(-1, 1)

        # factor effect of xs
        dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
        concentrate = concentrate + self.get_theta(dzs0)

        # perturbation effect (only when some cells are not already perturbed)
        ps = np.ones_like(us_i)
        if np.sum(np.abs(ps - us_i)) >= 1:
            dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
            concentrate = concentrate + self.get_theta(dzs)
            zs = zs + dzs0 + dzs
        else:
            zs = zs + dzs0

    if library_sizes is None:
        library_sizes = np.sum(xs, axis=1, keepdims=True)
    elif isinstance(library_sizes, list):
        library_sizes = np.array(library_sizes).reshape(-1, 1)
    elif len(library_sizes.shape) == 1:
        library_sizes = library_sizes.reshape(-1, 1)

    # get_counts expects decoder logits (feature space), not latent vectors.
    counts = self.get_counts(concentrate, library_sizes=library_sizes)

    return counts, zs
|
|
916
|
+
|
|
917
|
+
def _cell_shift(self, zs, perturb_idx, perturb):
    """Return the latent shift produced by factor ``perturb_idx`` at value ``perturb``.

    Non-2D ``perturb`` is reshaped to a column ``(-1, 1)``. With
    ``turn_off_cell_specific`` the effect network sees only the perturbation. Otherwise it
    receives ``[zs, perturb]``.
    """
    if perturb.ndim != 2:
        perturb = perturb.reshape(-1, 1)
    effect = self.cell_factor_effect[perturb_idx]
    if self.turn_off_cell_specific:
        return effect(perturb)
    return effect([zs, perturb])
|
|
933
|
+
|
|
934
|
+
def get_cell_shift(self,
                   zs,
                   perturb_idx,
                   perturb_us,
                   batch_size: int = 1024):
    """
    Return cells' latent-space changes induced by setting factor ``perturb_idx`` to ``perturb_us``.

    ``zs`` and ``perturb_us`` are paired row by row and processed in batches. The
    per-batch ``_cell_shift`` outputs are concatenated into one NumPy array.
    """
    latents = convert_to_tensor(zs, device=self.get_device())
    perturbations = convert_to_tensor(perturb_us, device=self.get_device())
    loader = DataLoader(CustomDataset2(latents, perturbations), batch_size=batch_size, shuffle=False)

    shifts = []
    with tqdm(total=len(loader), desc='', unit='batch') as progress:
        for z_batch, p_batch, _ in loader:
            shifts.append(tensor_to_numpy(self._cell_shift(z_batch, perturb_idx, p_batch)))
            progress.update(1)

    return np.concatenate(shifts)
|
|
958
|
+
|
|
959
|
+
def _get_theta(self, delta_zs):
    """Decode latent vectors (or latent shifts) into feature-space logits."""
    decoded = self.decoder_concentrate(delta_zs)
    return decoded
|
|
961
|
+
|
|
962
|
+
def get_theta(self,
              delta_zs,
              batch_size: int = 1024):
    """
    Return feature-space logits decoded from latent vectors or latent shifts.

    Rows of ``delta_zs`` are decoded in batches with ``_get_theta`` and concatenated into
    one NumPy array.
    """
    latents = convert_to_tensor(delta_zs, device=self.get_device())
    loader = DataLoader(CustomDataset(latents), batch_size=batch_size, shuffle=False)

    decoded = []
    with tqdm(total=len(loader), desc='', unit='batch') as progress:
        for batch, _ in loader:
            decoded.append(tensor_to_numpy(self._get_theta(batch)))
            progress.update(1)

    return np.concatenate(decoded)
|
|
982
|
+
|
|
983
|
+
def _count(self, concentrate, library_size=None):
    """Convert decoder logits into expected values.

    * ``'bernoulli'``: returns the Bernoulli mean ``sigmoid(concentrate)``.
    * any other likelihood: returns ``softmax(concentrate) * library_size``.

    The softmax equals the normalised ``exp(concentrate)`` used by the model
    (``DirichletMultinomial(total_count=1).mean``) and ``Multinomial(logits=...).probs``.
    It is computed stably, so large logits no longer overflow to NaN. The previous
    multinomial path used ``Multinomial(total_count=1e8).mean``, which is ``probs * 1e8``
    and inflated counts by 1e8.

    Parameters
    ----------
    concentrate
        Logits in feature space (normalised over the last dimension).
    library_size
        Per-cell total counts, broadcast against the proportions. When ``None``, the
        proportions themselves are returned for non-Bernoulli likelihoods.
    """
    if self.loss_func == 'bernoulli':
        return torch.sigmoid(concentrate)

    theta = torch.softmax(concentrate, dim=-1)
    if library_size is None:
        return theta
    return theta * library_size
|
|
995
|
+
|
|
996
|
+
def get_counts(self, concentrate,
               library_sizes,
               batch_size: int = 1024):
    """
    Convert feature-space logits into expected counts, batch by batch.

    ``library_sizes`` may be a list or a 1-D/2-D array or tensor. 1-D inputs are
    reshaped to a column so they broadcast per cell. Results from ``_count`` are
    concatenated into one NumPy array.
    """
    logits = convert_to_tensor(concentrate, device=self.get_device())

    sizes = library_sizes
    if type(sizes) is list:
        sizes = np.array(sizes).reshape(-1, 1)
    elif len(sizes.shape) == 1:
        sizes = sizes.reshape(-1, 1)
    sizes = convert_to_tensor(sizes, device=self.get_device())

    loader = DataLoader(CustomDataset2(logits, sizes), batch_size=batch_size, shuffle=False)

    expected = []
    with tqdm(total=len(loader), desc='', unit='batch') as progress:
        for logit_batch, size_batch, _ in loader:
            expected.append(tensor_to_numpy(self._count(logit_batch, size_batch)))
            progress.update(1)

    return np.concatenate(expected)
|
|
1020
|
+
|
|
1021
|
+
def preprocess(self, xs, threshold=0):
    """Prepare an expression matrix for the model and return a dense array.

    For the Bernoulli likelihood the data are binarised via ``binarize`` on an AnnData
    wrapper with ``threshold``. For the other likelihoods values are rounded to integers.
    SciPy sparse input is densified before rounding, because ``np.round`` does not operate
    on sparse matrices and previously produced a non-array result.
    """
    if self.loss_func == 'bernoulli':
        ad = sc.AnnData(xs)
        binarize(ad, threshold=threshold)
        xs = ad.X.copy()
    else:
        if sparse.issparse(xs):
            xs = xs.toarray()
        xs = np.round(xs)

    if sparse.issparse(xs):
        xs = xs.toarray()
    return xs
|
|
1032
|
+
|
|
1033
|
+
def fit(self, xs,
        us = None,
        ys = None,
        zs = None,
        num_epochs: int = 500,
        learning_rate: float = 0.0001,
        batch_size: int = 256,
        algo: Literal['adam','rmsprop','adamw'] = 'adam',
        beta_1: float = 0.9,
        weight_decay: float = 0.005,
        decay_rate: float = 0.9,
        config_enum: str = 'parallel',
        threshold: int = 0,
        use_jax: bool = True):
    """
    Train the DensityFlow2 model.

    The model/guide pair depends on which inputs are given. With neither ``us`` nor ``ys``
    it is ``model1``/``guide1``; with ``us`` only, ``model2``/``guide2``; with ``ys``
    only, ``model3``/``guide3``; with both, ``model4``/``guide4``.

    Parameters
    ----------
    xs
        Single-cell expression matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are features.
    us
        Cell-level factor matrix (rows are cells, columns are factors). Optional.
    ys
        Desired factor matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are desired factors. Optional.
    zs
        Optional precomputed basal embeddings. They are forwarded to the model as
        ``embeds`` only when ``ys`` is given.
    num_epochs
        Number of training epochs.
    learning_rate
        Learning rate of the optimizer.
    batch_size
        Size of batch processing.
    algo
        Optimization algorithm: 'adam', 'rmsprop' or 'adamw' (case-insensitive).
    beta_1
        First-moment coefficient for Adam/AdamW. RMSprop has no betas and ignores it.
    weight_decay
        Weight decay passed to the optimizer.
    decay_rate
        Gamma of the exponential learning-rate scheduler.
    config_enum
        Enumeration strategy passed to ``config_enumerate``.
    threshold
        Binarization threshold forwarded to ``preprocess``.
    use_jax
        If toggled on, ``JitTraceEnum_ELBO`` is used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
        the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow2 in the shell command.
    """
    xs = self.preprocess(xs, threshold=threshold)
    xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
    if us is not None:
        us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
    if ys is not None:
        ys = convert_to_tensor(ys, dtype=self.dtype, device=self.get_device())
    if zs is not None:
        zs = convert_to_tensor(zs, dtype=self.dtype, device=self.get_device())

    dataset = CustomDataset4(xs, us, ys, zs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # setup the optimizer
    algo_name = algo.lower()
    if algo_name == 'rmsprop':
        optimizer = torch.optim.RMSprop
        # RMSprop has no ``betas`` argument; passing it raises TypeError.
        optim_params = {'lr': learning_rate, 'weight_decay': weight_decay}
    elif algo_name == 'adam':
        optimizer = torch.optim.Adam
        optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}
    elif algo_name == 'adamw':
        optimizer = torch.optim.AdamW
        optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}
    else:
        raise ValueError("An optimization algorithm must be specified.")
    # NOTE(review): scheduler.step() is never called in this method, so decay_rate
    # currently has no effect on the learning rate. Confirm whether per-epoch decay is intended.
    scheduler = ExponentialLR({'optimizer': optimizer, 'optim_args': optim_params, 'gamma': decay_rate})

    pyro.clear_param_store()

    # set up the loss(es) for inference, wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label form the sampled discrete categorical distribution in the model
    Elbo = JitTraceEnum_ELBO if use_jax else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
    if us is None:
        if ys is None:
            guide = config_enumerate(self.guide1, config_enum, expand=True)
            loss_basic = SVI(self.model1, guide, scheduler, loss=elbo)
        else:
            guide = config_enumerate(self.guide3, config_enum, expand=True)
            loss_basic = SVI(self.model3, guide, scheduler, loss=elbo)
    else:
        if ys is None:
            guide = config_enumerate(self.guide2, config_enum, expand=True)
            loss_basic = SVI(self.model2, guide, scheduler, loss=elbo)
        else:
            guide = config_enumerate(self.guide4, config_enum, expand=True)
            loss_basic = SVI(self.model4, guide, scheduler, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]
    num_losses = len(losses)

    with tqdm(total=num_epochs, desc='Training', unit='epoch') as pbar:
        for epoch in range(num_epochs):
            epoch_losses = [0.0] * num_losses
            for batch_x, batch_u, batch_y, batch_z, _ in dataloader:
                if us is None:
                    batch_u = None
                if ys is None:
                    batch_y = None
                if zs is None:
                    batch_z = None

                for loss_id in range(num_losses):
                    if batch_u is None:
                        if batch_y is None:
                            new_loss = losses[loss_id].step(batch_x)
                        else:
                            new_loss = losses[loss_id].step(batch_x, batch_y, batch_z)
                    else:
                        if batch_y is None:
                            new_loss = losses[loss_id].step(batch_x, batch_u)
                        else:
                            new_loss = losses[loss_id].step(batch_x, batch_u, batch_y, batch_z)
                    epoch_losses[loss_id] += new_loss

            # average per batch and format for the progress bar
            n_batches = len(dataloader)
            str_loss = " ".join("{:.4f}".format(v / n_batches) for v in epoch_losses)

            # Update progress bar
            pbar.set_postfix({'loss': str_loss})
            pbar.update(1)
|
|
1159
|
+
|
|
1160
|
+
@classmethod
def save_model(cls, model, file_path, compression=False):
    """Pickle ``model`` (after switching it to eval mode) to ``file_path``.

    When ``compression`` is true the file is written through ``gzip``.
    """
    target = os.path.abspath(file_path)

    model.eval()
    opener = gzip.open if compression else open
    with opener(target, 'wb') as handle:
        pickle.dump(model, handle)

    print(f'Model saved to {target}')
|
|
1174
|
+
|
|
1175
|
+
@classmethod
def load_model(cls, file_path):
    """Unpickle and return a model from ``file_path`` (gzip-decompressed if the path ends in 'gz')."""
    print(f'Model loaded from {file_path}')

    source = os.path.abspath(file_path)
    opener = gzip.open if source.endswith('gz') else open
    # SECURITY NOTE: pickle.load can execute arbitrary code; only load files from trusted sources.
    with opener(source, 'rb') as handle:
        model = pickle.load(handle)

    return model
|
|
1189
|
+
|
|
1190
|
+
|
|
1191
|
+
# Usage hint appended to the argparse description.
EXAMPLE_RUN = "example run: DensityFlow2 --help"
|
|
1194
|
+
|
|
1195
|
+
def parse_args():
    """Build the DensityFlow2 command-line parser and parse ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Parsed options. Dashes in option names become underscores (e.g. ``args.data_file``).
    """
    parser = argparse.ArgumentParser(
        description="DensityFlow2\n{}".format(EXAMPLE_RUN))

    parser.add_argument(
        "--cuda", action="store_true", help="use GPU(s) to speed up training"
    )
    parser.add_argument(
        "--jit", action="store_true", help="use PyTorch jit to speed up training"
    )
    parser.add_argument(
        "-n", "--num-epochs", default=200, type=int, help="number of epochs to run"
    )
    parser.add_argument(
        "-enum",
        "--enum-discrete",
        default="parallel",
        help="parallel, sequential or none. uses parallel enumeration by default",
    )
    parser.add_argument(
        "-data",
        "--data-file",
        default=None,
        type=str,
        help="the data file",
    )
    parser.add_argument(
        "-cf",
        "--cell-factor-file",
        default=None,
        type=str,
        help="the file for the record of cell-level factors",
    )
    parser.add_argument(
        "-bs",
        "--batch-size",
        default=1000,
        type=int,
        help="number of cells to be considered in a batch",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=0.0001,
        type=float,
        help="learning rate for Adam optimizer",
    )
    parser.add_argument(
        "-cs",
        "--codebook-size",
        default=100,
        type=int,
        help="size of vector quantization codebook",
    )
    parser.add_argument(
        "--z-dist",
        default='gumbel',
        type=str,
        choices=['normal','laplacian','studentt','gumbel','cauchy'],
        help="distribution model for latent representation",
    )
    parser.add_argument(
        "-zd",
        "--z-dim",
        default=10,
        type=int,
        help="size of the tensor representing the latent variable z variable",
    )
    parser.add_argument(
        "-likeli",
        "--likelihood",
        default='negbinomial',
        type=str,
        choices=['negbinomial', 'multinomial', 'poisson', 'bernoulli'],
        help="specify the distribution likelihood function",
    )
    parser.add_argument(
        "-zi",
        "--zeroinflate",
        action="store_true",
        help="use zero-inflated estimation",
    )
    parser.add_argument(
        "-id",
        "--inverse-dispersion",
        default=10.0,
        type=float,
        help="inverse dispersion prior for negative binomial",
    )
    parser.add_argument(
        "-hl",
        "--hidden-layers",
        nargs="+",
        default=[500],
        type=int,
        help="a tuple (or list) of MLP layers to be used in the neural networks "
        "representing the parameters of the distributions in our model",
    )
    parser.add_argument(
        "-hla",
        "--hidden-layer-activation",
        default='relu',
        type=str,
        choices=['relu','softplus','leakyrelu','linear'],
        help="activation function for hidden layers",
    )
    parser.add_argument(
        "-plf",
        "--post-layer-function",
        nargs="+",
        default=['layernorm'],
        type=str,
        # Help text now matches the actual default (['layernorm']).
        help="post functions for hidden layers, could be none, dropout, layernorm, batchnorm, or combination, default is 'layernorm'",
    )
    parser.add_argument(
        "-paf",
        "--post-activation-function",
        nargs="+",
        default=['none'],
        type=str,
        help="post functions for activation layers, could be none or dropout, default is 'none'",
    )
    parser.add_argument(
        "-64",
        "--float64",
        action="store_true",
        help="use double float precision",
    )
    parser.add_argument(
        "-dr",
        "--decay-rate",
        default=0.9,
        type=float,
        help="decay rate for Adam optimizer",
    )
    parser.add_argument(
        "--layer-dropout-rate",
        default=0.1,
        type=float,
        help="dropout rate for neural networks",
    )
    parser.add_argument(
        "-b1",
        "--beta-1",
        default=0.95,
        type=float,
        help="beta-1 parameter for Adam optimizer",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="seed for controlling randomness in this example",
    )
    parser.add_argument(
        "--save-model",
        default=None,
        type=str,
        help="path to save model for prediction",
    )
    args = parser.parse_args()
    return args
|
|
1357
|
+
|
|
1358
|
+
def main():
    """Command-line entry point: load data, train DensityFlow2 and optionally save it.

    Raises
    ------
    ValueError
        If ``--data-file`` is missing or does not point to an existing path.
    """
    args = parse_args()
    # Explicit check instead of ``assert`` so validation still runs under ``python -O``.
    if args.data_file is None or not os.path.exists(args.data_file):
        raise ValueError("data file must be provided")

    if args.seed is not None:
        set_random_seed(args.seed)

    if args.float64:
        dtype = torch.float64
        torch.set_default_dtype(torch.float64)
    else:
        dtype = torch.float32
        torch.set_default_dtype(torch.float32)

    xs = dt.fread(file=args.data_file, header=True).to_numpy()
    us = None
    if args.cell_factor_file is not None:
        us = dt.fread(file=args.cell_factor_file, header=True).to_numpy()

    input_size = xs.shape[1]
    cell_factor_size = 0 if us is None else us.shape[1]

    ###########################################
    df = DensityFlow2(
        input_size=input_size,
        cell_factor_size=cell_factor_size,
        inverse_dispersion=args.inverse_dispersion,
        z_dim=args.z_dim,
        hidden_layers=args.hidden_layers,
        hidden_layer_activation=args.hidden_layer_activation,
        use_cuda=args.cuda,
        config_enum=args.enum_discrete,
        use_zeroinflate=args.zeroinflate,
        loss_func=args.likelihood,
        nn_dropout=args.layer_dropout_rate,
        post_layer_fct=args.post_layer_function,
        post_act_fct=args.post_activation_function,
        codebook_size=args.codebook_size,
        z_dist=args.z_dist,
        dtype=dtype,
    )

    df.fit(xs, us=us,
           num_epochs=args.num_epochs,
           learning_rate=args.learning_rate,
           batch_size=args.batch_size,
           beta_1=args.beta_1,
           decay_rate=args.decay_rate,
           use_jax=args.jit,
           config_enum=args.enum_discrete,
           )

    if args.save_model is not None:
        if args.save_model.endswith('gz'):
            DensityFlow2.save_model(df, args.save_model, compression=True)
        else:
            DensityFlow2.save_model(df, args.save_model)
|
|
1418
|
+
|
|
1419
|
+
|
|
1420
|
+
|
|
1421
|
+
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|