SURE-tools 2.0.9__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/PerturbFlow.py +1324 -0
- SURE/SURE.py +0 -47
- SURE/__init__.py +3 -2
- {sure_tools-2.0.9.dist-info → sure_tools-2.1.0.dist-info}/METADATA +1 -1
- {sure_tools-2.0.9.dist-info → sure_tools-2.1.0.dist-info}/RECORD +9 -8
- sure_tools-2.1.0.dist-info/entry_points.txt +3 -0
- sure_tools-2.0.9.dist-info/entry_points.txt +0 -2
- {sure_tools-2.0.9.dist-info → sure_tools-2.1.0.dist-info}/WHEEL +0 -0
- {sure_tools-2.0.9.dist-info → sure_tools-2.1.0.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.0.9.dist-info → sure_tools-2.1.0.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
ADDED
@@ -0,0 +1,1324 @@
import pyro
import pyro.distributions as dist
from pyro.optim import ExponentialLR
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_probs
from torch.distributions import constraints
from torch.distributions.transforms import SoftmaxTransform

from .utils.custom_mlp import MLP, Exp
from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor


import os
import argparse
import random
import numpy as np
import datatable as dt
from tqdm import tqdm
from scipy import sparse

import scanpy as sc
from .atac import binarize

from typing import Literal

import warnings
warnings.filterwarnings("ignore")

import dill as pickle
import gzip
from packaging.version import Version
torch_version = torch.__version__


def set_random_seed(seed):
    # Set seed for PyTorch
    torch.manual_seed(seed)

    # If using CUDA, set the seed for CUDA
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU setups.

    # Set seed for NumPy
    np.random.seed(seed)

    # Set seed for Python's random module
    random.seed(seed)

    # Set seed for Pyro
    pyro.set_rng_seed(seed)

|
|
58
|
+
"""SUccinct REpresentation of single-omics cells
|
|
59
|
+
|
|
60
|
+
Parameters
|
|
61
|
+
----------
|
|
62
|
+
inpute_size
|
|
63
|
+
Number of features (e.g., genes, peaks, proteins, etc.) per cell.
|
|
64
|
+
codebook_size
|
|
65
|
+
Number of metacells.
|
|
66
|
+
cell_factor_size
|
|
67
|
+
Number of cell-level factors.
|
|
68
|
+
z_dim
|
|
69
|
+
Dimensionality of latent states and metacells.
|
|
70
|
+
hidden_layers
|
|
71
|
+
A list give the numbers of neurons for each hidden layer.
|
|
72
|
+
loss_func
|
|
73
|
+
The likelihood model for single-cell data generation.
|
|
74
|
+
|
|
75
|
+
One of the following:
|
|
76
|
+
* ``'negbinomial'`` - negative binomial distribution (default)
|
|
77
|
+
* ``'poisson'`` - poisson distribution
|
|
78
|
+
* ``'multinomial'`` - multinomial distribution
|
|
79
|
+
z_dist
|
|
80
|
+
The distribution model for latent states.
|
|
81
|
+
|
|
82
|
+
One of the following:
|
|
83
|
+
* ``'normal'`` - normal distribution
|
|
84
|
+
* ``'laplacian'`` - Laplacian distribution
|
|
85
|
+
* ``'studentt'`` - Student-t distribution.
|
|
86
|
+
use_cuda
|
|
87
|
+
A boolean option for switching on cuda device.
|
|
88
|
+
|
|
89
|
+
Examples
|
|
90
|
+
--------
|
|
91
|
+
>>>
|
|
92
|
+
>>>
|
|
93
|
+
>>>
|
|
94
|
+
|
|
95
|
+
"""
|
|
96
|
+
def __init__(self,
|
|
97
|
+
input_size: int,
|
|
98
|
+
codebook_size: int = 200,
|
|
99
|
+
cell_factor_size: int = 0,
|
|
100
|
+
cell_factor_names: list = None,
|
|
101
|
+
supervised_mode: bool = False,
|
|
102
|
+
z_dim: int = 10,
|
|
103
|
+
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
104
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
|
|
105
|
+
inverse_dispersion: float = 10.0,
|
|
106
|
+
use_zeroinflate: bool = True,
|
|
107
|
+
hidden_layers: list = [300],
|
|
108
|
+
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
109
|
+
nn_dropout: float = 0.1,
|
|
110
|
+
post_layer_fct: list = ['layernorm'],
|
|
111
|
+
post_act_fct: list = None,
|
|
112
|
+
config_enum: str = 'parallel',
|
|
113
|
+
use_cuda: bool = False,
|
|
114
|
+
seed: int = 42,
|
|
115
|
+
dtype = torch.float32, # type: ignore
|
|
116
|
+
):
|
|
117
|
+
super().__init__()
|
|
118
|
+
|
|
119
|
+
self.input_size = input_size
|
|
120
|
+
self.cell_factor_size = cell_factor_size
|
|
121
|
+
self.inverse_dispersion = inverse_dispersion
|
|
122
|
+
self.latent_dim = z_dim
|
|
123
|
+
self.hidden_layers = hidden_layers
|
|
124
|
+
self.decoder_hidden_layers = hidden_layers[::-1]
|
|
125
|
+
self.allow_broadcast = config_enum == 'parallel'
|
|
126
|
+
self.use_cuda = use_cuda
|
|
127
|
+
self.loss_func = loss_func
|
|
128
|
+
self.options = None
|
|
129
|
+
self.code_size=codebook_size
|
|
130
|
+
self.supervised_mode=supervised_mode
|
|
131
|
+
self.latent_dist = z_dist
|
|
132
|
+
self.dtype = dtype
|
|
133
|
+
self.use_zeroinflate=use_zeroinflate
|
|
134
|
+
self.nn_dropout = nn_dropout
|
|
135
|
+
self.post_layer_fct = post_layer_fct
|
|
136
|
+
self.post_act_fct = post_act_fct
|
|
137
|
+
self.hidden_layer_activation = hidden_layer_activation
|
|
138
|
+
self.cell_factor_names = cell_factor_names
|
|
139
|
+
|
|
140
|
+
self.codebook_weights = None
|
|
141
|
+
|
|
142
|
+
set_random_seed(seed)
|
|
143
|
+
self.setup_networks()
|
|
144
|
+
|
|
145
|
+
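    # A minimal construction sketch (hedged): `counts` and its shape are
    # illustrative assumptions, not part of the package.
    #
    # >>> import numpy as np
    # >>> counts = np.random.poisson(1.0, size=(1000, 2000))   # cells x features
    # >>> model = PerturbFlow(input_size=counts.shape[1],
    # ...                     codebook_size=100,
    # ...                     z_dim=10,
    # ...                     loss_func='negbinomial')
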
    def setup_networks(self):
        latent_dim = self.latent_dim
        hidden_sizes = self.hidden_layers

        nn_layer_norm, nn_batch_norm, nn_layer_dropout = False, False, False
        na_layer_norm, na_batch_norm, na_layer_dropout = False, False, False

        if self.post_layer_fct is not None:
            nn_layer_norm = ('layernorm' in self.post_layer_fct) or ('layer_norm' in self.post_layer_fct)
            nn_batch_norm = ('batchnorm' in self.post_layer_fct) or ('batch_norm' in self.post_layer_fct)
            nn_layer_dropout = 'dropout' in self.post_layer_fct

        if self.post_act_fct is not None:
            na_layer_norm = ('layernorm' in self.post_act_fct) or ('layer_norm' in self.post_act_fct)
            na_batch_norm = ('batchnorm' in self.post_act_fct) or ('batch_norm' in self.post_act_fct)
            na_layer_dropout = 'dropout' in self.post_act_fct

        if nn_layer_norm and nn_batch_norm and nn_layer_dropout:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
        elif nn_layer_norm and nn_layer_dropout:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.LayerNorm(layer.module.out_features))
        elif nn_batch_norm and nn_layer_dropout:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features))
        elif nn_layer_norm and nn_batch_norm:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
        elif nn_layer_norm:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.LayerNorm(layer.module.out_features)
        elif nn_batch_norm:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.BatchNorm1d(layer.module.out_features)
        elif nn_layer_dropout:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.Dropout(self.nn_dropout)
        else:
            post_layer_fct = lambda layer_ix, total_layers, layer: None

        if na_layer_norm and na_batch_norm and na_layer_dropout:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
        elif na_layer_norm and na_layer_dropout:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.LayerNorm(layer.module.out_features))
        elif na_batch_norm and na_layer_dropout:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features))
        elif na_layer_norm and na_batch_norm:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
        elif na_layer_norm:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.LayerNorm(layer.module.out_features)
        elif na_batch_norm:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.BatchNorm1d(layer.module.out_features)
        elif na_layer_dropout:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.Dropout(self.nn_dropout)
        else:
            post_act_fct = lambda layer_ix, total_layers, layer: None

        if self.hidden_layer_activation == 'relu':
            activate_fct = nn.ReLU
        elif self.hidden_layer_activation == 'softplus':
            activate_fct = nn.Softplus
        elif self.hidden_layer_activation == 'leakyrelu':
            activate_fct = nn.LeakyReLU
        elif self.hidden_layer_activation == 'linear':
            activate_fct = nn.Identity

        if self.supervised_mode:
            self.encoder_n = MLP(
                [self.input_size] + hidden_sizes + [self.code_size],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )
        else:
            self.encoder_n = MLP(
                [self.latent_dim] + hidden_sizes + [self.code_size],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )

        self.encoder_zn = MLP(
            [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
            activation=activate_fct,
            output_activation=[None, Exp],
            post_layer_fct=post_layer_fct,
            post_act_fct=post_act_fct,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )

        if self.cell_factor_size > 0:
            self.cell_factor_effect = nn.ModuleList()
            for i in np.arange(self.cell_factor_size):
                self.cell_factor_effect.append(MLP(
                    [self.latent_dim + 1] + hidden_sizes + [self.latent_dim],
                    activation=activate_fct,
                    output_activation=None,
                    post_layer_fct=post_layer_fct,
                    post_act_fct=post_act_fct,
                    allow_broadcast=self.allow_broadcast,
                    use_cuda=self.use_cuda,
                ))

        self.decoder_concentrate = MLP(
            [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
            activation=activate_fct,
            output_activation=None,
            post_layer_fct=post_layer_fct,
            post_act_fct=post_act_fct,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )

        if self.latent_dist == 'studentt':
            self.codebook = MLP(
                [self.code_size] + hidden_sizes + [[latent_dim, latent_dim]],
                activation=activate_fct,
                output_activation=[Exp, None],
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )
        else:
            self.codebook = MLP(
                [self.code_size] + hidden_sizes + [latent_dim],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )

        if self.use_cuda:
            self.cuda()

    def get_device(self):
        return next(self.parameters()).device

    def cutoff(self, xs, thresh=None):
        eps = torch.finfo(xs.dtype).eps

        if thresh is not None:
            if eps < thresh:
                eps = thresh

        xs = xs.clamp(min=eps)

        if torch.any(torch.isnan(xs)):
            xs[torch.isnan(xs)] = eps

        return xs

    def softmax(self, xs):
        #xs = SoftmaxTransform()(xs)
        xs = dist.Multinomial(total_count=1, logits=xs).mean
        return xs

    def sigmoid(self, xs):
        #sigm_enc = nn.Sigmoid()
        #xs = sigm_enc(xs)
        #xs = clamp_probs(xs)
        xs = dist.Bernoulli(logits=xs).mean
        return xs

    def softmax_logit(self, xs):
        eps = torch.finfo(xs.dtype).eps
        xs = self.softmax(xs)
        xs = torch.logit(xs, eps=eps)
        return xs

    def logit(self, xs):
        eps = torch.finfo(xs.dtype).eps
        xs = torch.logit(xs, eps=eps)
        return xs

    def dirimulti_param(self, xs):
        # NOTE: assumes `self.dirimulti_mass` has been set by the caller;
        # it is not initialized in __init__.
        xs = self.dirimulti_mass * self.sigmoid(xs)
        return xs

    def multi_param(self, xs):
        xs = self.softmax(xs)
        return xs

    def model1(self, xs):
        pyro.module('PerturbFlow', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        # keep the codebook input on the same device/dtype as the data
        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            prior = torch.zeros(batch_size, self.code_size, **self.options)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))

            zn_loc = torch.matmul(ns, acs_loc)
            #zn_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
            elif self.latent_dist == 'laplacian':
                zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'cauchy':
                zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'normal':
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'gumbel':
                zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

            zs = zns
            concentrate = self.decoder_concentrate(zs)
            if self.loss_func == 'bernoulli':
                log_theta = concentrate
            else:
                rate = concentrate.exp()
                if self.loss_func != 'poisson':
                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
            elif self.loss_func == 'bernoulli':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)

    def guide1(self, xs):
        with pyro.plate('data'):
            zn_loc, zn_scale = self.encoder_zn(xs)
            zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

            alpha = self.encoder_n(zns)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

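    # A minimal sketch (hedged) of how a model/guide pair above is driven by
    # Pyro's SVI with enumeration over the discrete metacell assignment 'n';
    # `model` and `data` are illustrative assumptions, mirroring what `fit`
    # does internally.
    #
    # >>> guide = config_enumerate(model.guide1, 'parallel', expand=True)
    # >>> svi = SVI(model.model1, guide,
    # ...           pyro.optim.Adam({'lr': 1e-4}),
    # ...           loss=TraceEnum_ELBO(max_plate_nesting=1,
    # ...                               strict_enumeration_warning=False))
    # >>> loss = svi.step(data)   # data: a (cells x features) tensor
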
    def model2(self, xs, us=None):
        pyro.module('PerturbFlow', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            prior = torch.zeros(batch_size, self.code_size, **self.options)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))

            zn_loc = torch.matmul(ns, acs_loc)
            #zn_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
            elif self.latent_dist == 'laplacian':
                zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'cauchy':
                zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'normal':
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'gumbel':
                zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

            if self.cell_factor_size > 0:
                #zus = self.decoder_undesired([zns, us])
                zus = None
                for i in np.arange(self.cell_factor_size):
                    if i == 0:
                        zus = self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
                    else:
                        zus = zus + self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
                zs = zns + zus
            else:
                zs = zns

            concentrate = self.decoder_concentrate(zs)
            if self.loss_func == 'bernoulli':
                log_theta = concentrate
            else:
                rate = concentrate.exp()
                if self.loss_func != 'poisson':
                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
            elif self.loss_func == 'bernoulli':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)

    def guide2(self, xs, us=None):
        with pyro.plate('data'):
            zn_loc, zn_scale = self.encoder_zn(xs)
            zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

            alpha = self.encoder_n(zns)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

    def model3(self, xs, ys, embeds=None):
        pyro.module('PerturbFlow', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            #prior = torch.zeros(batch_size, self.code_size, **self.options)
            prior = self.encoder_n(xs)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)

            zn_loc = torch.matmul(ns, acs_loc)
            #prior_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                if embeds is None:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'laplacian':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'cauchy':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'normal':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'gumbel':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)

            zs = zns

            concentrate = self.decoder_concentrate(zs)
            if self.loss_func == 'bernoulli':
                log_theta = concentrate
            else:
                rate = concentrate.exp()
                if self.loss_func != 'poisson':
                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
            elif self.loss_func == 'bernoulli':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)

    def guide3(self, xs, ys, embeds=None):
        with pyro.plate('data'):
            if embeds is None:
                zn_loc, zn_scale = self.encoder_zn(xs)
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

    def model4(self, xs, us, ys, embeds=None):
        pyro.module('PerturbFlow', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            #prior = torch.zeros(batch_size, self.code_size, **self.options)
            prior = self.encoder_n(xs)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)

            zn_loc = torch.matmul(ns, acs_loc)
            #prior_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                if embeds is None:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'laplacian':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'cauchy':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'normal':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'gumbel':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)

            if self.cell_factor_size > 0:
                #zus = self.decoder_undesired([zns, us])
                zus = None
                for i in np.arange(self.cell_factor_size):
                    if i == 0:
                        zus = self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
                    else:
                        zus = zus + self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
                zs = zns + zus
            else:
                zs = zns

            concentrate = self.decoder_concentrate(zs)
            if self.loss_func == 'bernoulli':
                log_theta = concentrate
            else:
                rate = concentrate.exp()
                if self.loss_func != 'poisson':
                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
            elif self.loss_func == 'bernoulli':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)

    def guide4(self, xs, us, ys, embeds=None):
        with pyro.plate('data'):
            if embeds is None:
                zn_loc, zn_scale = self.encoder_zn(xs)
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

    def _get_codebook(self):
        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            _, cb = self.codebook(I)
        else:
            cb = self.codebook(I)
        return cb

    def get_codebook(self):
        """
        Return the mean part of the metacell codebook
        """
        cb = self._get_codebook()
        cb = tensor_to_numpy(cb)
        return cb

    def _get_cell_embedding(self, xs):
        zns, _ = self.encoder_zn(xs)
        return zns

    def get_cell_embedding(self,
                           xs,
                           batch_size: int = 1024):
        """
        Return cells' latent representations

        Parameters
        ----------
        xs
            Single-cell expression matrix. It should be a NumPy array or a PyTorch tensor.
        batch_size
            Size of batch processing.
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        Z = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                zns = self._get_cell_embedding(X_batch)
                Z.append(tensor_to_numpy(zns))
                pbar.update(1)

        Z = np.concatenate(Z)
        return Z

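    # A minimal sketch (hedged) of retrieving representations after training;
    # `counts` is a hypothetical (cells x features) count matrix.
    #
    # >>> Z = model.get_cell_embedding(counts)   # (cells x z_dim) NumPy array
    # >>> cb = model.get_codebook()              # (codebook_size x z_dim) metacell coordinates
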
    def _code(self, xs):
        if self.supervised_mode:
            alpha = self.encoder_n(xs)
        else:
            zns, _ = self.encoder_zn(xs)
            alpha = self.encoder_n(zns)
        return alpha

    def code(self, xs, batch_size=1024):
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        A = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                a = self._code(X_batch)
                A.append(tensor_to_numpy(a))
                pbar.update(1)

        A = np.concatenate(A)
        return A

    def _soft_assignments(self, xs):
        alpha = self._code(xs)
        alpha = self.softmax(alpha)
        return alpha

    def soft_assignments(self, xs, batch_size=1024):
        """
        Map cells to metacells and return the probabilistic values of metacell assignments
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        A = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                a = self._soft_assignments(X_batch)
                A.append(tensor_to_numpy(a))
                pbar.update(1)

        A = np.concatenate(A)
        return A

    def _hard_assignments(self, xs):
        alpha = self._code(xs)
        res, ind = torch.topk(alpha, 1)
        ns = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
        return ns

    def hard_assignments(self, xs, batch_size=1024):
        """
        Map cells to metacells and return the assigned metacell identities.
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        A = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                a = self._hard_assignments(X_batch)
                A.append(tensor_to_numpy(a))
                pbar.update(1)

        A = np.concatenate(A)
        return A

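    # A minimal sketch (hedged) of mapping cells to metacells; `counts` is a
    # hypothetical (cells x features) count matrix.
    #
    # >>> P = model.soft_assignments(counts)   # (cells x codebook_size) probabilities
    # >>> H = model.hard_assignments(counts)   # one-hot metacell identities
    # >>> labels = H.argmax(axis=1)            # integer metacell index per cell
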
    def _cell_state_response(self, xs, factor_idx, perturb):
        zns, _ = self.encoder_zn(xs)
        if isinstance(factor_idx, str):
            factor_idx = int(np.where(np.array(self.cell_factor_names) == factor_idx)[0])

        if perturb.ndim == 2:
            ms = self.cell_factor_effect[factor_idx]([zns, perturb])
        else:
            ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1, 1)])

        return ms

    def get_cell_state_response(self,
                                xs,
                                factor_idx,
                                perturb,
                                batch_size: int = 1024):
        """
        Return cells' changes in the latent space induced by a specific perturbation of a factor
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        ps = convert_to_tensor(perturb, device=self.get_device())
        dataset = CustomDataset2(xs, ps)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        Z = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, P_batch, _ in dataloader:
                zns = self._cell_state_response(X_batch, factor_idx, P_batch)
                Z.append(tensor_to_numpy(zns))
                pbar.update(1)

        Z = np.concatenate(Z)
        return Z

    def get_metacell_response(self, factor_idx, perturb):
        zs = self._get_codebook()
        ps = convert_to_tensor(perturb, device=self.get_device())

        if isinstance(factor_idx, str):
            factor_idx = int(np.where(np.array(self.cell_factor_names) == factor_idx)[0])

        ms = self.cell_factor_effect[factor_idx]([zs, ps])
        return tensor_to_numpy(ms)

    def _get_expression_response(self, delta_zs):
        return self.decoder_concentrate(delta_zs)

    def get_expression_response(self,
                                delta_zs,
                                batch_size: int = 1024):
        """
        Return the expression responses corresponding to changes in the latent space
        """
        delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
        dataset = CustomDataset(delta_zs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        R = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for delta_Z_batch, _ in dataloader:
                r = self._get_expression_response(delta_Z_batch)
                R.append(tensor_to_numpy(r))
                pbar.update(1)

        R = np.concatenate(R)
        return R

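    # A minimal sketch (hedged) of probing a perturbation: the factor name
    # 'dose' and the perturbation vector are illustrative assumptions and must
    # match the `us` matrix the model was trained with.
    #
    # >>> dz = model.get_cell_state_response(counts, 'dose',
    # ...                                    np.ones((counts.shape[0], 1)))
    # >>> dx = model.get_expression_response(dz)   # decoder readout of the shift
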
    def preprocess(self, xs, threshold=0):
        if self.loss_func == 'bernoulli':
            ad = sc.AnnData(xs)
            binarize(ad, threshold=threshold)
            xs = ad.X.copy()
        else:
            xs = np.round(xs)

        if sparse.issparse(xs):
            xs = xs.toarray()
        return xs

    def fit(self, xs,
            us = None,
            ys = None,
            zs = None,
            num_epochs: int = 200,
            learning_rate: float = 0.0001,
            batch_size: int = 256,
            algo: Literal['adam','rmsprop','adamw'] = 'adam',
            beta_1: float = 0.9,
            weight_decay: float = 0.005,
            decay_rate: float = 0.9,
            config_enum: str = 'parallel',
            threshold: int = 0,
            use_jax: bool = False):
        """
        Train the PerturbFlow model.

        Parameters
        ----------
        xs
            Single-cell expression matrix. It should be a NumPy array or a PyTorch tensor. Rows are cells and columns are features.
        us
            Cell-level factor matrix.
        ys
            Desired factor matrix. It should be a NumPy array or a PyTorch tensor. Rows are cells and columns are desired factors.
        num_epochs
            Number of training epochs.
        learning_rate
            Parameter for training.
        batch_size
            Size of batch processing.
        algo
            Optimization algorithm.
        beta_1
            Parameter for optimization.
        weight_decay
            Parameter for optimization.
        decay_rate
            Parameter for optimization.
        use_jax
            If toggled on, the JIT-compiled ELBO (JitTraceEnum_ELBO) is used to speed up training. CAUTION: This may raise errors,
            for unknown reasons, when called from a Python script or a Jupyter notebook. It works when running PerturbFlow from the
            shell command line.
        """
        xs = self.preprocess(xs, threshold=threshold)
        xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
        if us is not None:
            us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
        if ys is not None:
            ys = convert_to_tensor(ys, dtype=self.dtype, device=self.get_device())
        if zs is not None:
            zs = convert_to_tensor(zs, dtype=self.dtype, device=self.get_device())

        dataset = CustomDataset4(xs, us, ys, zs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # set up the optimizer
        optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}

        if algo.lower() == 'rmsprop':
            optimizer = torch.optim.RMSprop
        elif algo.lower() == 'adam':
            optimizer = torch.optim.Adam
        elif algo.lower() == 'adamw':
            optimizer = torch.optim.AdamW
        else:
            raise ValueError("An optimization algorithm must be specified.")
        scheduler = ExponentialLR({'optimizer': optimizer, 'optim_args': optim_params, 'gamma': decay_rate})

        pyro.clear_param_store()

        # set up the loss(es) for inference; wrapping the guide in config_enumerate builds the loss as a sum
        # by enumerating each class label for the sampled discrete categorical distribution in the model
        Elbo = JitTraceEnum_ELBO if use_jax else TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
        if us is None:
            if ys is None:
                guide = config_enumerate(self.guide1, config_enum, expand=True)
                loss_basic = SVI(self.model1, guide, scheduler, loss=elbo)
            else:
                guide = config_enumerate(self.guide3, config_enum, expand=True)
                loss_basic = SVI(self.model3, guide, scheduler, loss=elbo)
        else:
            if ys is None:
                guide = config_enumerate(self.guide2, config_enum, expand=True)
                loss_basic = SVI(self.model2, guide, scheduler, loss=elbo)
            else:
                guide = config_enumerate(self.guide4, config_enum, expand=True)
                loss_basic = SVI(self.model4, guide, scheduler, loss=elbo)

        # build a list of all losses considered
        losses = [loss_basic]
        num_losses = len(losses)

        with tqdm(total=num_epochs, desc='Training', unit='epoch') as pbar:
            for epoch in range(num_epochs):
                epoch_losses = [0.0] * num_losses
                for batch_x, batch_u, batch_y, batch_z, _ in dataloader:
                    if us is None:
                        batch_u = None
                    if ys is None:
                        batch_y = None
                    if zs is None:
                        batch_z = None

                    for loss_id in range(num_losses):
                        if batch_u is None:
                            if batch_y is None:
                                new_loss = losses[loss_id].step(batch_x)
                            else:
                                new_loss = losses[loss_id].step(batch_x, batch_y, batch_z)
                        else:
                            if batch_y is None:
                                new_loss = losses[loss_id].step(batch_x, batch_u)
                            else:
                                new_loss = losses[loss_id].step(batch_x, batch_u, batch_y, batch_z)
                        epoch_losses[loss_id] += new_loss

                avg_epoch_losses_ = map(lambda v: v / len(dataloader), epoch_losses)
                avg_epoch_losses = map(lambda v: "{:.4f}".format(v), avg_epoch_losses_)

                # store the loss
                str_loss = " ".join(map(str, avg_epoch_losses))

                # Update progress bar
                pbar.set_postfix({'loss': str_loss})
                pbar.update(1)

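    # A minimal end-to-end training sketch (hedged); `counts` and `doses` are
    # hypothetical NumPy arrays (cells x features and cells x 1, respectively).
    #
    # >>> model = PerturbFlow(input_size=counts.shape[1], cell_factor_size=1,
    # ...                     cell_factor_names=['dose'])
    # >>> model.fit(counts, us=doses, num_epochs=100, batch_size=256)
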
    @classmethod
    def save_model(cls, model, file_path, compression=False):
        """Save the model to the specified file path."""
        file_path = os.path.abspath(file_path)

        model.eval()
        if compression:
            with gzip.open(file_path, 'wb') as pickle_file:
                pickle.dump(model, pickle_file)
        else:
            with open(file_path, 'wb') as pickle_file:
                pickle.dump(model, pickle_file)

        print(f'Model saved to {file_path}')

    @classmethod
    def load_model(cls, file_path):
        """Load the model from the specified file path and return an instance."""
        file_path = os.path.abspath(file_path)
        if file_path.endswith('gz'):
            with gzip.open(file_path, 'rb') as pickle_file:
                model = pickle.load(pickle_file)
        else:
            with open(file_path, 'rb') as pickle_file:
                model = pickle.load(pickle_file)

        print(f'Model loaded from {file_path}')
        return model


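# A minimal save/load round trip (hedged); the path is an illustrative
# assumption. Models are pickled with dill and gzip-compressed when the
# path ends in 'gz'.
#
# >>> PerturbFlow.save_model(model, 'perturbflow.pkl.gz', compression=True)
# >>> model = PerturbFlow.load_model('perturbflow.pkl.gz')
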
EXAMPLE_RUN = (
    "example run: PerturbFlow --help"
)

def parse_args():
    parser = argparse.ArgumentParser(
        description="PerturbFlow\n{}".format(EXAMPLE_RUN))

    parser.add_argument(
        "--cuda", action="store_true", help="use GPU(s) to speed up training"
    )
    parser.add_argument(
        "--jit", action="store_true", help="use PyTorch jit to speed up training"
    )
    parser.add_argument(
        "-n", "--num-epochs", default=200, type=int, help="number of epochs to run"
    )
    parser.add_argument(
        "-enum",
        "--enum-discrete",
        default="parallel",
        help="parallel, sequential or none. uses parallel enumeration by default",
    )
    parser.add_argument(
        "-data",
        "--data-file",
        default=None,
        type=str,
        help="the data file",
    )
    parser.add_argument(
        "-cf",
        "--cell-factor-file",
        default=None,
        type=str,
        help="the file for the record of cell-level factors",
    )
    parser.add_argument(
        "-delta",
        "--delta",
        default=0.0,
        type=float,
        help="penalty weight for zero inflation loss",
    )
    parser.add_argument(
        "-64",
        "--float64",
        action="store_true",
        help="use double float precision",
    )
    parser.add_argument(
        "--z-dist",
        default='normal',
        type=str,
        choices=['normal','laplacian','studentt','cauchy','gumbel'],
        help="distribution model for latent representation",
    )
    parser.add_argument(
        "-cs",
        "--codebook-size",
        default=100,
        type=int,
        help="size of vector quantization codebook",
    )
    parser.add_argument(
        "-zd",
        "--z-dim",
        default=10,
        type=int,
        help="size of the tensor representing the latent variable z",
    )
    parser.add_argument(
        "-hl",
        "--hidden-layers",
        nargs="+",
        default=[500],
        type=int,
        help="a tuple (or list) of MLP layers to be used in the neural networks "
        "representing the parameters of the distributions in our model",
    )
    parser.add_argument(
        "-hla",
        "--hidden-layer-activation",
        default='relu',
        type=str,
        choices=['relu','softplus','leakyrelu','linear'],
        help="activation function for hidden layers",
    )
    parser.add_argument(
        "-plf",
        "--post-layer-function",
        nargs="+",
        default=['layernorm'],
        type=str,
        help="post functions for hidden layers, could be none, dropout, layernorm, batchnorm, or a combination; default is 'dropout layernorm'",
    )
    parser.add_argument(
        "-paf",
        "--post-activation-function",
        nargs="+",
        default=['none'],
        type=str,
        help="post functions for activation layers, could be none or dropout; default is 'none'",
    )
    parser.add_argument(
        "-id",
        "--inverse-dispersion",
        default=10.0,
        type=float,
        help="inverse dispersion prior for negative binomial",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=0.0001,
        type=float,
        help="learning rate for Adam optimizer",
    )
    parser.add_argument(
        "-dr",
        "--decay-rate",
        default=0.9,
        type=float,
        help="decay rate for Adam optimizer",
    )
    parser.add_argument(
        "--layer-dropout-rate",
        default=0.1,
        type=float,
        help="dropout rate for neural networks",
    )
    parser.add_argument(
        "-b1",
        "--beta-1",
        default=0.95,
        type=float,
        help="beta-1 parameter for Adam optimizer",
    )
    parser.add_argument(
        "-bs",
        "--batch-size",
        default=1000,
        type=int,
        help="number of cells to be considered in a batch",
    )
    parser.add_argument(
        "-gp",
        "--gate-prior",
        default=0.6,
        type=float,
        help="gate prior for zero-inflated model",
    )
    parser.add_argument(
        "-likeli",
        "--likelihood",
        default='negbinomial',
        type=str,
        choices=['negbinomial', 'multinomial', 'poisson', 'bernoulli'],
        help="specify the distribution likelihood function",
    )
    parser.add_argument(
        "-dirichlet",
        "--use-dirichlet",
        action="store_true",
        help="use Dirichlet distribution over gene frequency",
    )
    parser.add_argument(
        "-mass",
        "--dirichlet-mass",
        default=1,
        type=float,
        help="mass param for dirichlet model",
    )
    parser.add_argument(
        "-zi",
        "--zero-inflation",
        default='exact',
        type=str,
        choices=['none','exact','inexact'],
        help="use zero-inflated estimation",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="seed for controlling randomness in this example",
    )
    parser.add_argument(
        "--save-model",
        default=None,
        type=str,
        help="path to save model for prediction",
    )
    args = parser.parse_args()
    return args

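# A minimal command-line sketch (hedged); file names are illustrative
# assumptions. Each CSV is cells x columns with a header row, matching
# what dt.fread(...).to_numpy() expects below.
#
#   PerturbFlow -data counts.csv -cf factors.csv -n 200 -bs 256 \
#       --save-model perturbflow.pkl.gz
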
def main():
    args = parse_args()
    assert (
        (args.data_file is not None) and (
            os.path.exists(args.data_file))
    ), "data file must be provided"

    if args.seed is not None:
        set_random_seed(args.seed)

    if args.float64:
        dtype = torch.float64
        torch.set_default_dtype(torch.float64)
    else:
        dtype = torch.float32
        torch.set_default_dtype(torch.float32)

    xs = dt.fread(file=args.data_file, header=True).to_numpy()
    us = None
    if args.cell_factor_file is not None:
        us = dt.fread(file=args.cell_factor_file, header=True).to_numpy()

    input_size = xs.shape[1]
    cell_factor_size = 0 if us is None else us.shape[1]

    ###########################################
    perturbflow = PerturbFlow(
        input_size=input_size,
        cell_factor_size=cell_factor_size,
        inverse_dispersion=args.inverse_dispersion,
        z_dim=args.z_dim,
        z_dist=args.z_dist,
        loss_func=args.likelihood,
        use_zeroinflate=args.zero_inflation != 'none',
        hidden_layers=args.hidden_layers,
        hidden_layer_activation=args.hidden_layer_activation,
        nn_dropout=args.layer_dropout_rate,
        post_layer_fct=args.post_layer_function,
        post_act_fct=args.post_activation_function,
        codebook_size=args.codebook_size,
        config_enum=args.enum_discrete,
        use_cuda=args.cuda,
        dtype=dtype,
    )

    perturbflow.fit(xs, us=us,
                    num_epochs=args.num_epochs,
                    learning_rate=args.learning_rate,
                    batch_size=args.batch_size,
                    beta_1=args.beta_1,
                    decay_rate=args.decay_rate,
                    use_jax=args.jit,
                    config_enum=args.enum_discrete,
                    )

    if args.save_model is not None:
        if args.save_model.endswith('gz'):
            PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
        else:
            PerturbFlow.save_model(perturbflow, args.save_model)


if __name__ == "__main__":

    main()