SURE-tools 2.0.9__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/PerturbFlow.py ADDED
@@ -0,0 +1,1324 @@
import pyro
import pyro.distributions as dist
from pyro.optim import ExponentialLR
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_probs
from torch.distributions import constraints
from torch.distributions.transforms import SoftmaxTransform

from .utils.custom_mlp import MLP, Exp
from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor


import os
import argparse
import random
import numpy as np
import datatable as dt
from tqdm import tqdm
from scipy import sparse

import scanpy as sc
from .atac import binarize

from typing import Literal

import warnings
warnings.filterwarnings("ignore")

import dill as pickle
import gzip
from packaging.version import Version
torch_version = torch.__version__


def set_random_seed(seed):
    # Set seed for PyTorch
    torch.manual_seed(seed)

    # If using CUDA, set the seed for CUDA
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU setups.

    # Set seed for NumPy
    np.random.seed(seed)

    # Set seed for Python's random module
    random.seed(seed)

    # Set seed for Pyro
    pyro.set_rng_seed(seed)

class PerturbFlow(nn.Module):
    """SUccinct REpresentation of single-omics cells

    Parameters
    ----------
    input_size
        Number of features (e.g., genes, peaks, proteins, etc.) per cell.
    codebook_size
        Number of metacells.
    cell_factor_size
        Number of cell-level factors.
    cell_factor_names
        Optional list of names for the cell-level factors; used to look up a factor by name.
    supervised_mode
        If True, metacell assignments are predicted directly from the input profile
        and can be supervised with observed labels.
    z_dim
        Dimensionality of latent states and metacells.
    hidden_layers
        A list giving the number of neurons in each hidden layer.
    loss_func
        The likelihood model for single-cell data generation.

        One of the following:
        * ``'negbinomial'`` - negative binomial distribution (default)
        * ``'poisson'`` - Poisson distribution
        * ``'multinomial'`` - multinomial distribution
        * ``'bernoulli'`` - Bernoulli distribution (for binarized data)
    z_dist
        The distribution model for latent states.

        One of the following:
        * ``'normal'`` - normal distribution (default)
        * ``'laplacian'`` - Laplacian distribution
        * ``'studentt'`` - Student-t distribution
        * ``'cauchy'`` - Cauchy distribution
        * ``'gumbel'`` - Gumbel distribution
    inverse_dispersion
        Prior value of the inverse dispersion parameter for the negative binomial likelihood.
    use_zeroinflate
        Whether to use a zero-inflated likelihood.
    use_cuda
        A boolean option for switching on the CUDA device.
    seed
        Random seed.

    Examples
    --------
    >>> # Illustrative sketch only: a small model on a hypothetical
    >>> # cells-by-genes count matrix `xs`.
    >>> model = PerturbFlow(input_size=2000, codebook_size=100, z_dim=10)
    >>> model.fit(xs, num_epochs=100)        # doctest: +SKIP
    >>> zs = model.get_cell_embedding(xs)    # doctest: +SKIP

    """
    def __init__(self,
                 input_size: int,
                 codebook_size: int = 200,
                 cell_factor_size: int = 0,
                 cell_factor_names: list = None,
                 supervised_mode: bool = False,
                 z_dim: int = 10,
                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
                 inverse_dispersion: float = 10.0,
                 use_zeroinflate: bool = True,
                 hidden_layers: list = [300],
                 hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                 nn_dropout: float = 0.1,
                 post_layer_fct: list = ['layernorm'],
                 post_act_fct: list = None,
                 config_enum: str = 'parallel',
                 use_cuda: bool = False,
                 seed: int = 42,
                 dtype=torch.float32,  # type: ignore
                 ):
        super().__init__()

        self.input_size = input_size
        self.cell_factor_size = cell_factor_size
        self.inverse_dispersion = inverse_dispersion
        self.latent_dim = z_dim
        self.hidden_layers = hidden_layers
        self.decoder_hidden_layers = hidden_layers[::-1]
        self.allow_broadcast = config_enum == 'parallel'
        self.use_cuda = use_cuda
        self.loss_func = loss_func
        self.options = None
        self.code_size = codebook_size
        self.supervised_mode = supervised_mode
        self.latent_dist = z_dist
        self.dtype = dtype
        self.use_zeroinflate = use_zeroinflate
        self.nn_dropout = nn_dropout
        self.post_layer_fct = post_layer_fct
        self.post_act_fct = post_act_fct
        self.hidden_layer_activation = hidden_layer_activation
        self.cell_factor_names = cell_factor_names

        self.codebook_weights = None

        set_random_seed(seed)
        self.setup_networks()

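    # Network layout built by setup_networks() (summary of the code below):
    #   * encoder_zn:  cell profile -> (loc, scale) of the latent state z
    #   * encoder_n:   z (or the raw profile in supervised mode) -> metacell logits
    #   * cell_factor_effect: one MLP per cell-level factor, mapping [z, u_i] to a
    #     latent-space shift for that factor
    #   * decoder_concentrate: latent state -> per-feature concentration/logits
    #   * codebook:    one-hot metacell identity -> metacell coordinates in z-space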
    def setup_networks(self):
        latent_dim = self.latent_dim
        hidden_sizes = self.hidden_layers

        nn_layer_norm, nn_batch_norm, nn_layer_dropout = False, False, False
        na_layer_norm, na_batch_norm, na_layer_dropout = False, False, False

        if self.post_layer_fct is not None:
            nn_layer_norm = ('layernorm' in self.post_layer_fct) or ('layer_norm' in self.post_layer_fct)
            nn_batch_norm = ('batchnorm' in self.post_layer_fct) or ('batch_norm' in self.post_layer_fct)
            nn_layer_dropout = 'dropout' in self.post_layer_fct

        if self.post_act_fct is not None:
            na_layer_norm = ('layernorm' in self.post_act_fct) or ('layer_norm' in self.post_act_fct)
            na_batch_norm = ('batchnorm' in self.post_act_fct) or ('batch_norm' in self.post_act_fct)
            na_layer_dropout = 'dropout' in self.post_act_fct

        if nn_layer_norm and nn_batch_norm and nn_layer_dropout:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
        elif nn_layer_norm and nn_layer_dropout:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.LayerNorm(layer.module.out_features))
        elif nn_batch_norm and nn_layer_dropout:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features))
        elif nn_layer_norm and nn_batch_norm:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
        elif nn_layer_norm:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.LayerNorm(layer.module.out_features)
        elif nn_batch_norm:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.BatchNorm1d(layer.module.out_features)
        elif nn_layer_dropout:
            post_layer_fct = lambda layer_ix, total_layers, layer: nn.Dropout(self.nn_dropout)
        else:
            post_layer_fct = lambda layer_ix, total_layers, layer: None

        if na_layer_norm and na_batch_norm and na_layer_dropout:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
        elif na_layer_norm and na_layer_dropout:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.LayerNorm(layer.module.out_features))
        elif na_batch_norm and na_layer_dropout:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features))
        elif na_layer_norm and na_batch_norm:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
        elif na_layer_norm:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.LayerNorm(layer.module.out_features)
        elif na_batch_norm:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.BatchNorm1d(layer.module.out_features)
        elif na_layer_dropout:
            post_act_fct = lambda layer_ix, total_layers, layer: nn.Dropout(self.nn_dropout)
        else:
            post_act_fct = lambda layer_ix, total_layers, layer: None

        if self.hidden_layer_activation == 'relu':
            activate_fct = nn.ReLU
        elif self.hidden_layer_activation == 'softplus':
            activate_fct = nn.Softplus
        elif self.hidden_layer_activation == 'leakyrelu':
            activate_fct = nn.LeakyReLU
        elif self.hidden_layer_activation == 'linear':
            activate_fct = nn.Identity

        if self.supervised_mode:
            self.encoder_n = MLP(
                [self.input_size] + hidden_sizes + [self.code_size],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )
        else:
            self.encoder_n = MLP(
                [self.latent_dim] + hidden_sizes + [self.code_size],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )

        self.encoder_zn = MLP(
            [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
            activation=activate_fct,
            output_activation=[None, Exp],
            post_layer_fct=post_layer_fct,
            post_act_fct=post_act_fct,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )

        if self.cell_factor_size > 0:
            self.cell_factor_effect = nn.ModuleList()
            for i in np.arange(self.cell_factor_size):
                self.cell_factor_effect.append(MLP(
                    [self.latent_dim + 1] + hidden_sizes + [self.latent_dim],
                    activation=activate_fct,
                    output_activation=None,
                    post_layer_fct=post_layer_fct,
                    post_act_fct=post_act_fct,
                    allow_broadcast=self.allow_broadcast,
                    use_cuda=self.use_cuda,
                ))

        self.decoder_concentrate = MLP(
            [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
            activation=activate_fct,
            output_activation=None,
            post_layer_fct=post_layer_fct,
            post_act_fct=post_act_fct,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )

        if self.latent_dist == 'studentt':
            self.codebook = MLP(
                [self.code_size] + hidden_sizes + [[latent_dim, latent_dim]],
                activation=activate_fct,
                output_activation=[Exp, None],
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )
        else:
            self.codebook = MLP(
                [self.code_size] + hidden_sizes + [latent_dim],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )

        if self.use_cuda:
            self.cuda()

    def get_device(self):
        return next(self.parameters()).device

    def cutoff(self, xs, thresh=None):
        eps = torch.finfo(xs.dtype).eps

        if thresh is not None:
            if eps < thresh:
                eps = thresh

        xs = xs.clamp(min=eps)

        if torch.any(torch.isnan(xs)):
            xs[torch.isnan(xs)] = eps

        return xs

    def softmax(self, xs):
        #xs = SoftmaxTransform()(xs)
        xs = dist.Multinomial(total_count=1, logits=xs).mean
        return xs

    def sigmoid(self, xs):
        #sigm_enc = nn.Sigmoid()
        #xs = sigm_enc(xs)
        #xs = clamp_probs(xs)
        xs = dist.Bernoulli(logits=xs).mean
        return xs

    def softmax_logit(self, xs):
        eps = torch.finfo(xs.dtype).eps
        xs = self.softmax(xs)
        xs = torch.logit(xs, eps=eps)
        return xs

    def logit(self, xs):
        eps = torch.finfo(xs.dtype).eps
        xs = torch.logit(xs, eps=eps)
        return xs

    def dirimulti_param(self, xs):
        # Note: relies on self.dirimulti_mass, which is not set by __init__;
        # this helper is unused unless a Dirichlet-multinomial mass is configured.
        xs = self.dirimulti_mass * self.sigmoid(xs)
        return xs

    def multi_param(self, xs):
        xs = self.softmax(xs)
        return xs

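    # Generative model (model1, mirrored by guide1). For each cell:
    #   n  ~ OneHotCategorical(uniform logits)           -- metacell identity
    #   zn ~ z_dist(loc = codebook(n), scale = learned)  -- latent state around its metacell
    #   x  ~ loss_func likelihood parameterized by decoder_concentrate(zn)
    # model2 additionally shifts zn by the summed cell-factor effects before decoding;
    # model3/model4 are the supervised counterparts where the metacell identity `n`
    # (and optionally the embedding `zn`) is observed rather than inferred.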
    def model1(self, xs):
        pyro.module('PerturbFlow', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            prior = torch.zeros(batch_size, self.code_size, **self.options)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))

            zn_loc = torch.matmul(ns, acs_loc)
            #zn_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
            elif self.latent_dist == 'laplacian':
                zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'cauchy':
                zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'normal':
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'gumbel':
                zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

            zs = zns
            concentrate = self.decoder_concentrate(zs)
            if self.loss_func == 'bernoulli':
                log_theta = concentrate
            else:
                rate = concentrate.exp()
                if self.loss_func != 'poisson':
                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
            elif self.loss_func == 'bernoulli':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)

    def guide1(self, xs):
        with pyro.plate('data'):
            zn_loc, zn_scale = self.encoder_zn(xs)
            zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

            alpha = self.encoder_n(zns)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

    def model2(self, xs, us=None):
        pyro.module('PerturbFlow', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            prior = torch.zeros(batch_size, self.code_size, **self.options)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))

            zn_loc = torch.matmul(ns, acs_loc)
            #zn_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
            elif self.latent_dist == 'laplacian':
                zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'cauchy':
                zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'normal':
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'gumbel':
                zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

            if self.cell_factor_size > 0:
                #zus = self.decoder_undesired([zns, us])
                zus = None
                for i in np.arange(self.cell_factor_size):
                    if i == 0:
                        zus = self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
                    else:
                        zus = zus + self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
                zs = zns + zus
            else:
                zs = zns

            concentrate = self.decoder_concentrate(zs)
            if self.loss_func == 'bernoulli':
                log_theta = concentrate
            else:
                rate = concentrate.exp()
                if self.loss_func != 'poisson':
                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
            elif self.loss_func == 'bernoulli':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)

    def guide2(self, xs, us=None):
        with pyro.plate('data'):
            zn_loc, zn_scale = self.encoder_zn(xs)
            zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

            alpha = self.encoder_n(zns)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

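    # Supervised variants: model3/guide3 and model4/guide4 take the metacell
    # assignment `ys` as an observation of `n`, and can optionally clamp the
    # latent state to a precomputed embedding via `embeds`. model4 additionally
    # applies the cell-level factor effects, exactly as in model2.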
    def model3(self, xs, ys, embeds=None):
        pyro.module('PerturbFlow', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            #prior = torch.zeros(batch_size, self.code_size, **self.options)
            prior = self.encoder_n(xs)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)

            zn_loc = torch.matmul(ns, acs_loc)
            #prior_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                if embeds is None:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'laplacian':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'cauchy':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'normal':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'gumbel':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)

            zs = zns

            concentrate = self.decoder_concentrate(zs)
            if self.loss_func == 'bernoulli':
                log_theta = concentrate
            else:
                rate = concentrate.exp()
                if self.loss_func != 'poisson':
                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
            elif self.loss_func == 'bernoulli':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)

    def guide3(self, xs, ys, embeds=None):
        with pyro.plate('data'):
            if embeds is None:
                zn_loc, zn_scale = self.encoder_zn(xs)
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

    def model4(self, xs, us, ys, embeds=None):
        pyro.module('PerturbFlow', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            #prior = torch.zeros(batch_size, self.code_size, **self.options)
            prior = self.encoder_n(xs)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)

            zn_loc = torch.matmul(ns, acs_loc)
            #prior_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                if embeds is None:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'laplacian':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'cauchy':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'normal':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'gumbel':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)

            if self.cell_factor_size > 0:
                #zus = self.decoder_undesired([zns, us])
                zus = None
                for i in np.arange(self.cell_factor_size):
                    if i == 0:
                        zus = self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
                    else:
                        zus = zus + self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
                zs = zns + zus
            else:
                zs = zns

            concentrate = self.decoder_concentrate(zs)
            if self.loss_func == 'bernoulli':
                log_theta = concentrate
            else:
                rate = concentrate.exp()
                if self.loss_func != 'poisson':
                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
            elif self.loss_func == 'bernoulli':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)

    def guide4(self, xs, us, ys, embeds=None):
        with pyro.plate('data'):
            if embeds is None:
                zn_loc, zn_scale = self.encoder_zn(xs)
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

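    # The codebook maps each one-hot metacell identity to its coordinates in the
    # latent space, so decoding the identity matrix below yields the full set of
    # metacell locations (plus degrees of freedom for the Student-t variant).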
    def _get_codebook(self):
        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            _, cb = self.codebook(I)
        else:
            cb = self.codebook(I)
        return cb

    def get_codebook(self):
        """
        Return the mean part of the metacell codebook
        """
        cb = self._get_codebook()
        cb = tensor_to_numpy(cb)
        return cb

    def _get_cell_embedding(self, xs):
        zns, _ = self.encoder_zn(xs)
        return zns

    def get_cell_embedding(self,
                           xs,
                           batch_size: int = 1024):
        """
        Return cells' latent representations

        Parameters
        ----------
        xs
            Single-cell expression matrix. It should be a Numpy array or a Pytorch Tensor.
        batch_size
            Size of batch processing.
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        Z = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                zns = self._get_cell_embedding(X_batch)
                Z.append(tensor_to_numpy(zns))
                pbar.update(1)

        Z = np.concatenate(Z)
        return Z

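    # Metacell assignment: cells are first embedded by encoder_zn, then encoder_n
    # turns the embedding into per-metacell logits (in supervised mode the raw
    # profile is used directly). soft_assignments() returns the probabilities;
    # hard_assignments() one-hot-encodes the top-scoring metacell.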
    def _code(self, xs):
        if self.supervised_mode:
            alpha = self.encoder_n(xs)
        else:
            zns, _ = self.encoder_zn(xs)
            alpha = self.encoder_n(zns)
        return alpha

    def code(self, xs, batch_size=1024):
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        A = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                a = self._code(X_batch)
                A.append(tensor_to_numpy(a))
                pbar.update(1)

        A = np.concatenate(A)
        return A

    def _soft_assignments(self, xs):
        alpha = self._code(xs)
        alpha = self.softmax(alpha)
        return alpha

    def soft_assignments(self, xs, batch_size=1024):
        """
        Map cells to metacells and return the probabilistic values of metacell assignments
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        A = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                a = self._soft_assignments(X_batch)
                A.append(tensor_to_numpy(a))
                pbar.update(1)

        A = np.concatenate(A)
        return A

    def _hard_assignments(self, xs):
        alpha = self._code(xs)
        res, ind = torch.topk(alpha, 1)
        ns = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
        return ns

    def hard_assignments(self, xs, batch_size=1024):
        """
        Map cells to metacells and return the assigned metacell identities.
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        A = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                a = self._hard_assignments(X_batch)
                A.append(tensor_to_numpy(a))
                pbar.update(1)

        A = np.concatenate(A)
        return A

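    # Perturbation responses: each cell-level factor i has its own MLP that maps
    # the concatenation [z, u_i] to a shift in latent space. The helpers below
    # evaluate that shift for cells (given their encoded states) or for the
    # metacell codebook, and decode latent shifts back to expression space.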
    def _cell_state_response(self, xs, factor_idx, perturb):
        zns, _ = self.encoder_zn(xs)
        if isinstance(factor_idx, str):
            # look up the factor by name
            factor_idx = list(self.cell_factor_names).index(factor_idx)

        if perturb.ndim == 2:
            ms = self.cell_factor_effect[factor_idx]([zns, perturb])
        else:
            ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1, 1)])

        return ms

    def get_cell_state_response(self,
                                xs,
                                factor_idx,
                                perturb,
                                batch_size: int = 1024):
        """
        Return cells' changes in the latent space induced by a specific perturbation of a factor
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        ps = convert_to_tensor(perturb, device=self.get_device())
        dataset = CustomDataset2(xs, ps)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        Z = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, P_batch, _ in dataloader:
                zns = self._cell_state_response(X_batch, factor_idx, P_batch)
                Z.append(tensor_to_numpy(zns))
                pbar.update(1)

        Z = np.concatenate(Z)
        return Z

    def get_metacell_response(self, factor_idx, perturb):
        """
        Return metacells' changes in the latent space induced by a specific perturbation of a factor
        """
        zs = self._get_codebook()
        ps = convert_to_tensor(perturb, device=self.get_device())

        if isinstance(factor_idx, str):
            factor_idx = list(self.cell_factor_names).index(factor_idx)

        ms = self.cell_factor_effect[factor_idx]([zs, ps])
        return tensor_to_numpy(ms)

    def _get_expression_response(self, delta_zs):
        return self.decoder_concentrate(delta_zs)

    def get_expression_response(self,
                                delta_zs,
                                batch_size: int = 1024):
        """
        Return the expression responses implied by the given latent-space changes
        """
        delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
        dataset = CustomDataset(delta_zs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        R = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for delta_Z_batch, _ in dataloader:
                r = self._get_expression_response(delta_Z_batch)
                R.append(tensor_to_numpy(r))
                pbar.update(1)

        R = np.concatenate(R)
        return R

    def preprocess(self, xs, threshold=0):
        if self.loss_func == 'bernoulli':
            ad = sc.AnnData(xs)
            binarize(ad, threshold=threshold)
            xs = ad.X.copy()
        else:
            xs = np.round(xs)

        if sparse.issparse(xs):
            xs = xs.toarray()
        return xs

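    # fit() picks the model/guide pair from the supplied inputs:
    #   us is None, ys is None -> model1/guide1 (fully unsupervised)
    #   us given,   ys is None -> model2/guide2 (cell-level factors)
    #   us is None, ys given   -> model3/guide3 (observed metacell labels)
    #   us given,   ys given   -> model4/guide4 (labels + factors)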
    def fit(self, xs,
            us=None,
            ys=None,
            zs=None,
            num_epochs: int = 200,
            learning_rate: float = 0.0001,
            batch_size: int = 256,
            algo: Literal['adam','rmsprop','adamw'] = 'adam',
            beta_1: float = 0.9,
            weight_decay: float = 0.005,
            decay_rate: float = 0.9,
            config_enum: str = 'parallel',
            threshold: int = 0,
            use_jax: bool = False):
        """
        Train the PerturbFlow model.

        Parameters
        ----------
        xs
            Single-cell expression matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are features.
        us
            Cell-level factor matrix.
        ys
            Desired factor matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are desired factors.
        num_epochs
            Number of training epochs.
        learning_rate
            Parameter for training.
        batch_size
            Size of batch processing.
        algo
            Optimization algorithm.
        beta_1
            Parameter for optimization.
        weight_decay
            Parameter for optimization.
        decay_rate
            Parameter for optimization.
        threshold
            Binarization threshold used when ``loss_func`` is ``'bernoulli'``.
        use_jax
            If toggled on, the JIT-compiled ELBO is used to speed up training. CAUTION: for unknown
            reasons this can raise errors when called from a Python script or a Jupyter notebook;
            it works when PerturbFlow is run from the shell command line.
        """
        xs = self.preprocess(xs, threshold=threshold)
        xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
        if us is not None:
            us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
        if ys is not None:
            ys = convert_to_tensor(ys, dtype=self.dtype, device=self.get_device())
        if zs is not None:
            zs = convert_to_tensor(zs, dtype=self.dtype, device=self.get_device())

        dataset = CustomDataset4(xs, us, ys, zs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # setup the optimizer
        optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}

        if algo.lower() == 'rmsprop':
            optimizer = torch.optim.RMSprop
        elif algo.lower() == 'adam':
            optimizer = torch.optim.Adam
        elif algo.lower() == 'adamw':
            optimizer = torch.optim.AdamW
        else:
            raise ValueError("A valid optimization algorithm must be specified.")
        scheduler = ExponentialLR({'optimizer': optimizer, 'optim_args': optim_params, 'gamma': decay_rate})

        pyro.clear_param_store()

        # set up the loss(es) for inference; wrapping the guide in config_enumerate builds the loss as a sum
        # by enumerating each class label for the sampled discrete categorical distribution in the model
        Elbo = JitTraceEnum_ELBO if use_jax else TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
        if us is None:
            if ys is None:
                guide = config_enumerate(self.guide1, config_enum, expand=True)
                loss_basic = SVI(self.model1, guide, scheduler, loss=elbo)
            else:
                guide = config_enumerate(self.guide3, config_enum, expand=True)
                loss_basic = SVI(self.model3, guide, scheduler, loss=elbo)
        else:
            if ys is None:
                guide = config_enumerate(self.guide2, config_enum, expand=True)
                loss_basic = SVI(self.model2, guide, scheduler, loss=elbo)
            else:
                guide = config_enumerate(self.guide4, config_enum, expand=True)
                loss_basic = SVI(self.model4, guide, scheduler, loss=elbo)

        # build a list of all losses considered
        losses = [loss_basic]
        num_losses = len(losses)

        with tqdm(total=num_epochs, desc='Training', unit='epoch') as pbar:
            for epoch in range(num_epochs):
                epoch_losses = [0.0] * num_losses
                for batch_x, batch_u, batch_y, batch_z, _ in dataloader:
                    if us is None:
                        batch_u = None
                    if ys is None:
                        batch_y = None
                    if zs is None:
                        batch_z = None

                    for loss_id in range(num_losses):
                        if batch_u is None:
                            if batch_y is None:
                                new_loss = losses[loss_id].step(batch_x)
                            else:
                                new_loss = losses[loss_id].step(batch_x, batch_y, batch_z)
                        else:
                            if batch_y is None:
                                new_loss = losses[loss_id].step(batch_x, batch_u)
                            else:
                                new_loss = losses[loss_id].step(batch_x, batch_u, batch_y, batch_z)
                        epoch_losses[loss_id] += new_loss

                avg_epoch_losses_ = map(lambda v: v / len(dataloader), epoch_losses)
                avg_epoch_losses = map(lambda v: "{:.4f}".format(v), avg_epoch_losses_)

                # store the loss
                str_loss = " ".join(map(str, avg_epoch_losses))

                # Update progress bar
                pbar.set_postfix({'loss': str_loss})
                pbar.update(1)

    @classmethod
    def save_model(cls, model, file_path, compression=False):
        """Save the model to the specified file path."""
        file_path = os.path.abspath(file_path)

        model.eval()
        if compression:
            with gzip.open(file_path, 'wb') as pickle_file:
                pickle.dump(model, pickle_file)
        else:
            with open(file_path, 'wb') as pickle_file:
                pickle.dump(model, pickle_file)

        print(f'Model saved to {file_path}')

    @classmethod
    def load_model(cls, file_path):
        """Load the model from the specified file path and return an instance."""
        file_path = os.path.abspath(file_path)
        if file_path.endswith('gz'):
            with gzip.open(file_path, 'rb') as pickle_file:
                model = pickle.load(pickle_file)
        else:
            with open(file_path, 'rb') as pickle_file:
                model = pickle.load(pickle_file)

        print(f'Model loaded from {file_path}')
        return model
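
# A minimal end-to-end usage sketch (illustrative only, not part of the package API):
# it builds a small PerturbFlow on synthetic counts, trains briefly, then queries
# embeddings, metacell assignments, and a perturbation response. The synthetic
# matrices `xs`/`us` and the factor name 'dose' are hypothetical.
def _example_perturbflow_usage():
    xs = np.random.poisson(1.0, size=(200, 50)).astype('float32')    # 200 cells x 50 genes
    us = np.random.binomial(1, 0.5, (200, 1)).astype('float32')      # one binary cell-level factor
    model = PerturbFlow(input_size=50, codebook_size=10, cell_factor_size=1,
                        cell_factor_names=['dose'], z_dim=5)
    model.fit(xs, us=us, num_epochs=2, batch_size=64)
    zs = model.get_cell_embedding(xs)                                # (200, 5) latent states
    ns = model.hard_assignments(xs)                                  # one-hot metacell identities
    # latent shift each cell would undergo if its 'dose' factor were set to 1
    dzs = model.get_cell_state_response(xs, 'dose', np.ones((200, 1), dtype='float32'))
    return zs, ns, dzs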


EXAMPLE_RUN = (
    "example run: PerturbFlow --help"
)

def parse_args():
    parser = argparse.ArgumentParser(
        description="PerturbFlow\n{}".format(EXAMPLE_RUN))

    parser.add_argument(
        "--cuda", action="store_true", help="use GPU(s) to speed up training"
    )
    parser.add_argument(
        "--jit", action="store_true", help="use PyTorch jit to speed up training"
    )
    parser.add_argument(
        "-n", "--num-epochs", default=200, type=int, help="number of epochs to run"
    )
    parser.add_argument(
        "-enum",
        "--enum-discrete",
        default="parallel",
        help="parallel, sequential or none. uses parallel enumeration by default",
    )
    parser.add_argument(
        "-data",
        "--data-file",
        default=None,
        type=str,
        help="the data file",
    )
    parser.add_argument(
        "-cf",
        "--cell-factor-file",
        default=None,
        type=str,
        help="the file for the record of cell-level factors",
    )
    parser.add_argument(
        "-delta",
        "--delta",
        default=0.0,
        type=float,
        help="penalty weight for zero inflation loss",
    )
    parser.add_argument(
        "-64",
        "--float64",
        action="store_true",
        help="use double float precision",
    )
    parser.add_argument(
        "--z-dist",
        default='normal',
        type=str,
        choices=['normal', 'laplacian', 'studentt', 'cauchy', 'gumbel'],
        help="distribution model for latent representation",
    )
    parser.add_argument(
        "-cs",
        "--codebook-size",
        default=100,
        type=int,
        help="size of vector quantization codebook",
    )
    parser.add_argument(
        "-zd",
        "--z-dim",
        default=10,
        type=int,
        help="dimensionality of the latent variable z",
    )
    parser.add_argument(
        "-hl",
        "--hidden-layers",
        nargs="+",
        default=[500],
        type=int,
        help="a tuple (or list) of MLP layers to be used in the neural networks "
        "representing the parameters of the distributions in our model",
    )
    parser.add_argument(
        "-hla",
        "--hidden-layer-activation",
        default='relu',
        type=str,
        choices=['relu', 'softplus', 'leakyrelu', 'linear'],
        help="activation function for hidden layers",
    )
    parser.add_argument(
        "-plf",
        "--post-layer-function",
        nargs="+",
        default=['layernorm'],
        type=str,
        help="post functions for hidden layers, could be none, dropout, layernorm, batchnorm, or a combination; default is 'layernorm'",
    )
    parser.add_argument(
        "-paf",
        "--post-activation-function",
        nargs="+",
        default=['none'],
        type=str,
        help="post functions for activation layers, could be none or dropout, default is 'none'",
    )
    parser.add_argument(
        "-id",
        "--inverse-dispersion",
        default=10.0,
        type=float,
        help="inverse dispersion prior for negative binomial",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=0.0001,
        type=float,
        help="learning rate for Adam optimizer",
    )
    parser.add_argument(
        "-dr",
        "--decay-rate",
        default=0.9,
        type=float,
        help="decay rate for Adam optimizer",
    )
    parser.add_argument(
        "--layer-dropout-rate",
        default=0.1,
        type=float,
        help="dropout rate for neural networks",
    )
    parser.add_argument(
        "-b1",
        "--beta-1",
        default=0.95,
        type=float,
        help="beta-1 parameter for Adam optimizer",
    )
    parser.add_argument(
        "-bs",
        "--batch-size",
        default=1000,
        type=int,
        help="number of cells to be considered in a batch",
    )
    parser.add_argument(
        "-gp",
        "--gate-prior",
        default=0.6,
        type=float,
        help="gate prior for zero-inflated model",
    )
    parser.add_argument(
        "-likeli",
        "--likelihood",
        default='negbinomial',
        type=str,
        choices=['negbinomial', 'multinomial', 'poisson', 'bernoulli'],
        help="specify the distribution likelihood function",
    )
    parser.add_argument(
        "-dirichlet",
        "--use-dirichlet",
        action="store_true",
        help="use Dirichlet distribution over gene frequency",
    )
    parser.add_argument(
        "-mass",
        "--dirichlet-mass",
        default=1,
        type=float,
        help="mass param for dirichlet model",
    )
    parser.add_argument(
        "-zi",
        "--zero-inflation",
        default='exact',
        type=str,
        choices=['none', 'exact', 'inexact'],
        help="use zero-inflated estimation",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="seed for controlling randomness in this example",
    )
    parser.add_argument(
        "--save-model",
        default=None,
        type=str,
        help="path to save model for prediction",
    )
    args = parser.parse_args()
    return args

def main():
    args = parse_args()
    assert (
        (args.data_file is not None) and (
            os.path.exists(args.data_file))
    ), "data file must be provided"

    if args.seed is not None:
        set_random_seed(args.seed)

    if args.float64:
        dtype = torch.float64
        torch.set_default_dtype(torch.float64)
    else:
        dtype = torch.float32
        torch.set_default_dtype(torch.float32)

    xs = dt.fread(file=args.data_file, header=True).to_numpy()
    us = None
    if args.cell_factor_file is not None:
        us = dt.fread(file=args.cell_factor_file, header=True).to_numpy()

    input_size = xs.shape[1]
    cell_factor_size = 0 if us is None else us.shape[1]

    latent_dist = args.z_dist

    ###########################################
    # Only options supported by the PerturbFlow constructor are passed on; the
    # Dirichlet-, gate-prior-, and delta-related flags above are accepted for
    # compatibility but are not used by this model.
    perturbflow = PerturbFlow(
        input_size=input_size,
        cell_factor_size=cell_factor_size,
        codebook_size=args.codebook_size,
        z_dim=args.z_dim,
        z_dist=latent_dist,
        loss_func=args.likelihood,
        inverse_dispersion=args.inverse_dispersion,
        use_zeroinflate=args.zero_inflation != 'none',
        hidden_layers=args.hidden_layers,
        hidden_layer_activation=args.hidden_layer_activation,
        nn_dropout=args.layer_dropout_rate,
        post_layer_fct=args.post_layer_function,
        post_act_fct=args.post_activation_function,
        config_enum=args.enum_discrete,
        use_cuda=args.cuda,
        seed=args.seed if args.seed is not None else 42,
        dtype=dtype,
    )

    perturbflow.fit(xs, us=us,
                    num_epochs=args.num_epochs,
                    learning_rate=args.learning_rate,
                    batch_size=args.batch_size,
                    beta_1=args.beta_1,
                    decay_rate=args.decay_rate,
                    use_jax=args.jit,
                    config_enum=args.enum_discrete,
                    )

    if args.save_model is not None:
        if args.save_model.endswith('gz'):
            PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
        else:
            PerturbFlow.save_model(perturbflow, args.save_model)


if __name__ == "__main__":

    main()