SURE-tools 2.4.7__py3-none-any.whl → 2.4.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/DensityFlow2.py ADDED
@@ -0,0 +1,1422 @@
1
+ import pyro
2
+ import pyro.distributions as dist
3
+ from pyro.optim import ExponentialLR
4
+ from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import DataLoader
9
+ from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_probs
10
+ from torch.distributions import constraints
11
+ from torch.distributions.transforms import SoftmaxTransform
12
+
13
+ from .utils.custom_mlp import MLP, Exp, ZeroBiasMLP
14
+ from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor
15
+
16
+
17
+ import os
18
+ import argparse
19
+ import random
20
+ import numpy as np
21
+ import datatable as dt
22
+ from tqdm import tqdm
23
+ from scipy import sparse
24
+
25
+ import scanpy as sc
26
+ from .atac import binarize
27
+
28
+ from typing import Literal
29
+
30
+ import warnings
31
+ warnings.filterwarnings("ignore")
32
+
33
+ import dill as pickle
34
+ import gzip
35
+ from packaging.version import Version
36
+ torch_version = torch.__version__
37
+
38
+
39
+ def set_random_seed(seed):
40
+ # Set seed for PyTorch
41
+ torch.manual_seed(seed)
42
+
43
+ # If using CUDA, set the seed for CUDA
44
+ if torch.cuda.is_available():
45
+ torch.cuda.manual_seed(seed)
46
+ torch.cuda.manual_seed_all(seed) # For multi-GPU setups.
47
+
48
+ # Set seed for NumPy
49
+ np.random.seed(seed)
50
+
51
+ # Set seed for Python's random module
52
+ random.seed(seed)
53
+
54
+ # Set seed for Pyro
55
+ pyro.set_rng_seed(seed)
56
+
57
+ class DensityFlow2(nn.Module):
58
+ def __init__(self,
59
+ input_size: int,
60
+ codebook_size: int = 200,
61
+ cell_factor_size: int = 0,
62
+ turn_off_cell_specific: bool = False,
63
+ supervised_mode: bool = False,
64
+ z_dim: int = 10,
65
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
66
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
67
+ inverse_dispersion: float = 10.0,
68
+ use_zeroinflate: bool = False,
69
+ hidden_layers: list = [500],
70
+ hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
71
+ nn_dropout: float = 0.1,
72
+ post_layer_fct: list = ['layernorm'],
73
+ post_act_fct: list = None,
74
+ config_enum: str = 'parallel',
75
+ use_cuda: bool = True,
76
+ seed: int = 42,
77
+ zero_bias: bool|list = True,
78
+ dtype = torch.float32, # type: ignore
79
+ ):
80
+ super().__init__()
81
+
82
+ self.input_size = input_size
83
+ self.cell_factor_size = cell_factor_size
84
+ self.inverse_dispersion = inverse_dispersion
85
+ self.latent_dim = z_dim
86
+ self.hidden_layers = hidden_layers
87
+ self.decoder_hidden_layers = hidden_layers[::-1]
88
+ self.allow_broadcast = config_enum == 'parallel'
89
+ self.use_cuda = use_cuda
90
+ self.loss_func = loss_func
91
+ self.options = None
92
+ self.code_size=codebook_size
93
+ self.supervised_mode=supervised_mode
94
+ self.latent_dist = z_dist
95
+ self.dtype = dtype
96
+ self.use_zeroinflate=use_zeroinflate
97
+ self.nn_dropout = nn_dropout
98
+ self.post_layer_fct = post_layer_fct
99
+ self.post_act_fct = post_act_fct
100
+ self.hidden_layer_activation = hidden_layer_activation
101
+ if isinstance(zero_bias, list):
102
+ self.use_bias = [not x for x in zero_bias]
103
+ else:
104
+ self.use_bias = [not zero_bias] * self.cell_factor_size
105
+ #self.use_bias = not zero_bias
106
+ self.turn_off_cell_specific = turn_off_cell_specific
107
+
108
+ self.codebook_weights = None
109
+
110
+ set_random_seed(seed)
111
+ self.setup_networks()
112
+
113
+ print(f"🧬 DensityFlow2 Initialized:")
114
+ print(f" - Latent Dimension: {self.latent_dim}")
115
+ print(f" - Gene Dimension: {self.input_size}")
116
+ print(f" - Hidden Dimensions: {self.hidden_layers}")
117
+ print(f" - Device: {self.get_device()}")
118
+ print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
119
+
120
+ def setup_networks(self):
121
+ latent_dim = self.latent_dim
122
+ hidden_sizes = self.hidden_layers
123
+
124
+ nn_layer_norm, nn_batch_norm, nn_layer_dropout = False, False, False
125
+ na_layer_norm, na_batch_norm, na_layer_dropout = False, False, False
126
+
127
+ if self.post_layer_fct is not None:
128
+ nn_layer_norm=True if ('layernorm' in self.post_layer_fct) or ('layer_norm' in self.post_layer_fct) else False
129
+ nn_batch_norm=True if ('batchnorm' in self.post_layer_fct) or ('batch_norm' in self.post_layer_fct) else False
130
+ nn_layer_dropout=True if 'dropout' in self.post_layer_fct else False
131
+
132
+ if self.post_act_fct is not None:
133
+ na_layer_norm=True if ('layernorm' in self.post_act_fct) or ('layer_norm' in self.post_act_fct) else False
134
+ na_batch_norm=True if ('batchnorm' in self.post_act_fct) or ('batch_norm' in self.post_act_fct) else False
135
+ na_layer_dropout=True if 'dropout' in self.post_act_fct else False
136
+
137
+ if nn_layer_norm and nn_batch_norm and nn_layer_dropout:
138
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout),nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
139
+ elif nn_layer_norm and nn_layer_dropout:
140
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.LayerNorm(layer.module.out_features))
141
+ elif nn_batch_norm and nn_layer_dropout:
142
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features))
143
+ elif nn_layer_norm and nn_batch_norm:
144
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
145
+ elif nn_layer_norm:
146
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.LayerNorm(layer.module.out_features)
147
+ elif nn_batch_norm:
148
+ post_layer_fct = lambda layer_ix, total_layers, layer:nn.BatchNorm1d(layer.module.out_features)
149
+ elif nn_layer_dropout:
150
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.Dropout(self.nn_dropout)
151
+ else:
152
+ post_layer_fct = lambda layer_ix, total_layers, layer: None
153
+
154
+ if na_layer_norm and na_batch_norm and na_layer_dropout:
155
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout),nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
156
+ elif na_layer_norm and na_layer_dropout:
157
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.LayerNorm(layer.module.out_features))
158
+ elif na_batch_norm and na_layer_dropout:
159
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features))
160
+ elif na_layer_norm and na_batch_norm:
161
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
162
+ elif na_layer_norm:
163
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.LayerNorm(layer.module.out_features)
164
+ elif na_batch_norm:
165
+ post_act_fct = lambda layer_ix, total_layers, layer:nn.BatchNorm1d(layer.module.out_features)
166
+ elif na_layer_dropout:
167
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.Dropout(self.nn_dropout)
168
+ else:
169
+ post_act_fct = lambda layer_ix, total_layers, layer: None
170
+
171
+ if self.hidden_layer_activation == 'relu':
172
+ activate_fct = nn.ReLU
173
+ elif self.hidden_layer_activation == 'softplus':
174
+ activate_fct = nn.Softplus
175
+ elif self.hidden_layer_activation == 'leakyrelu':
176
+ activate_fct = nn.LeakyReLU
177
+ elif self.hidden_layer_activation == 'linear':
178
+ activate_fct = nn.Identity
179
+
180
+ if self.supervised_mode:
181
+ self.encoder_n = MLP(
182
+ [self.input_size] + hidden_sizes + [self.code_size],
183
+ activation=activate_fct,
184
+ output_activation=None,
185
+ post_layer_fct=post_layer_fct,
186
+ post_act_fct=post_act_fct,
187
+ allow_broadcast=self.allow_broadcast,
188
+ use_cuda=self.use_cuda,
189
+ )
190
+ else:
191
+ self.encoder_n = MLP(
192
+ [self.latent_dim] + hidden_sizes + [self.code_size],
193
+ activation=activate_fct,
194
+ output_activation=None,
195
+ post_layer_fct=post_layer_fct,
196
+ post_act_fct=post_act_fct,
197
+ allow_broadcast=self.allow_broadcast,
198
+ use_cuda=self.use_cuda,
199
+ )
200
+
201
+ self.encoder_zn = MLP(
202
+ [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
203
+ activation=activate_fct,
204
+ output_activation=[None, Exp],
205
+ post_layer_fct=post_layer_fct,
206
+ post_act_fct=post_act_fct,
207
+ allow_broadcast=self.allow_broadcast,
208
+ use_cuda=self.use_cuda,
209
+ )
210
+
211
+ if self.loss_func == 'negbinomial':
212
+ self.encoder_inverse_dispersion = MLP(
213
+ [self.latent_dim] + hidden_sizes + [[self.input_size, self.input_size]],
214
+ activation=activate_fct,
215
+ output_activation=[Exp, Exp],
216
+ post_layer_fct=post_layer_fct,
217
+ post_act_fct=post_act_fct,
218
+ allow_broadcast=self.allow_broadcast,
219
+ use_cuda=self.use_cuda,
220
+ )
221
+
222
+ if self.cell_factor_size>0:
223
+ self.cell_factor_effect = nn.ModuleList()
224
+ for i in np.arange(self.cell_factor_size):
225
+ if self.use_bias[i]:
226
+ if self.turn_off_cell_specific:
227
+ self.cell_factor_effect.append(MLP(
228
+ [1] + self.decoder_hidden_layers + [self.latent_dim],
229
+ activation=activate_fct,
230
+ output_activation=None,
231
+ post_layer_fct=post_layer_fct,
232
+ post_act_fct=post_act_fct,
233
+ allow_broadcast=self.allow_broadcast,
234
+ use_cuda=self.use_cuda,
235
+ )
236
+ )
237
+ else:
238
+ self.cell_factor_effect.append(MLP(
239
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
240
+ activation=activate_fct,
241
+ output_activation=None,
242
+ post_layer_fct=post_layer_fct,
243
+ post_act_fct=post_act_fct,
244
+ allow_broadcast=self.allow_broadcast,
245
+ use_cuda=self.use_cuda,
246
+ )
247
+ )
248
+ else:
249
+ if self.turn_off_cell_specific:
250
+ self.cell_factor_effect.append(ZeroBiasMLP(
251
+ [1] + self.decoder_hidden_layers + [self.latent_dim],
252
+ activation=activate_fct,
253
+ output_activation=None,
254
+ post_layer_fct=post_layer_fct,
255
+ post_act_fct=post_act_fct,
256
+ allow_broadcast=self.allow_broadcast,
257
+ use_cuda=self.use_cuda,
258
+ )
259
+ )
260
+ else:
261
+ self.cell_factor_effect.append(ZeroBiasMLP(
262
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
263
+ activation=activate_fct,
264
+ output_activation=None,
265
+ post_layer_fct=post_layer_fct,
266
+ post_act_fct=post_act_fct,
267
+ allow_broadcast=self.allow_broadcast,
268
+ use_cuda=self.use_cuda,
269
+ )
270
+ )
271
+
272
+ self.decoder_concentrate = MLP(
273
+ [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
274
+ activation=activate_fct,
275
+ output_activation=None,
276
+ post_layer_fct=post_layer_fct,
277
+ post_act_fct=post_act_fct,
278
+ allow_broadcast=self.allow_broadcast,
279
+ use_cuda=self.use_cuda,
280
+ )
281
+
282
+ if self.latent_dist == 'studentt':
283
+ self.codebook = MLP(
284
+ [self.code_size] + hidden_sizes + [[latent_dim,latent_dim]],
285
+ activation=activate_fct,
286
+ output_activation=[Exp,None],
287
+ post_layer_fct=post_layer_fct,
288
+ post_act_fct=post_act_fct,
289
+ allow_broadcast=self.allow_broadcast,
290
+ use_cuda=self.use_cuda,
291
+ )
292
+ else:
293
+ self.codebook = MLP(
294
+ [self.code_size] + hidden_sizes + [latent_dim],
295
+ activation=activate_fct,
296
+ output_activation=None,
297
+ post_layer_fct=post_layer_fct,
298
+ post_act_fct=post_act_fct,
299
+ allow_broadcast=self.allow_broadcast,
300
+ use_cuda=self.use_cuda,
301
+ )
302
+
303
+ if self.use_cuda:
304
+ self.cuda()
305
+
306
+ def get_device(self):
307
+ return next(self.parameters()).device
308
+
309
+ def cutoff(self, xs, thresh=None):
310
+ eps = torch.finfo(xs.dtype).eps
311
+
312
+ if thresh is not None:
313
+ if eps < thresh:
314
+ eps = thresh
315
+
316
+ xs = xs.clamp(min=eps)
317
+
318
+ if torch.any(torch.isnan(xs)):
319
+ xs[torch.isnan(xs)] = eps
320
+
321
+ return xs
322
+
323
+ def softmax(self, xs):
324
+ #xs = SoftmaxTransform()(xs)
325
+ xs = dist.Multinomial(total_count=1, logits=xs).mean
326
+ return xs
327
+
328
+ def sigmoid(self, xs):
329
+ #sigm_enc = nn.Sigmoid()
330
+ #xs = sigm_enc(xs)
331
+ #xs = clamp_probs(xs)
332
+ xs = dist.Bernoulli(logits=xs).mean
333
+ return xs
334
+
335
+ def softmax_logit(self, xs):
336
+ eps = torch.finfo(xs.dtype).eps
337
+ xs = self.softmax(xs)
338
+ xs = torch.logit(xs, eps=eps)
339
+ return xs
340
+
341
+ def logit(self, xs):
342
+ eps = torch.finfo(xs.dtype).eps
343
+ xs = torch.logit(xs, eps=eps)
344
+ return xs
345
+
346
+ def dirimulti_param(self, xs):
347
+ xs = self.dirimulti_mass * self.sigmoid(xs)
348
+ return xs
349
+
350
+ def multi_param(self, xs):
351
+ xs = self.softmax(xs)
352
+ return xs
353
+
354
+ def model1(self, xs):
355
+ pyro.module('DensityFlow2', self)
356
+
357
+ eps = torch.finfo(xs.dtype).eps
358
+ batch_size = xs.size(0)
359
+ self.options = dict(dtype=xs.dtype, device=xs.device)
360
+
361
+ if self.loss_func=='negbinomial':
362
+ total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
363
+ xs.new_ones(self.input_size), constraint=constraints.positive)
364
+
365
+ if self.use_zeroinflate:
366
+ gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
367
+
368
+ acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
369
+
370
+ I = torch.eye(self.code_size, **self.options)
371
+ if self.latent_dist=='studentt':
372
+ acs_dof,acs_loc = self.codebook(I)
373
+ else:
374
+ acs_loc = self.codebook(I)
375
+
376
+ with pyro.plate('data'):
377
+ prior = torch.zeros(batch_size, self.code_size, **self.options)
378
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))
379
+
380
+ zn_loc = torch.matmul(ns,acs_loc)
381
+ #zn_scale = torch.matmul(ns,acs_scale)
382
+ zn_scale = acs_scale
383
+
384
+ if self.latent_dist == 'studentt':
385
+ prior_dof = torch.matmul(ns,acs_dof)
386
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
387
+ elif self.latent_dist == 'laplacian':
388
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
389
+ elif self.latent_dist == 'cauchy':
390
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
391
+ elif self.latent_dist == 'normal':
392
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
393
+ elif self.latent_dist == 'gumbel':
394
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
395
+
396
+ zs = zns
397
+ concentrate = self.decoder_concentrate(zs)
398
+ if self.loss_func in ['bernoulli']:
399
+ log_theta = concentrate
400
+ else:
401
+ rate = concentrate.exp()
402
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
403
+ if self.loss_func == 'poisson':
404
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
405
+
406
+ if self.loss_func == 'negbinomial':
407
+ if self.use_zeroinflate:
408
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
409
+ else:
410
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
411
+ elif self.loss_func == 'poisson':
412
+ if self.use_zeroinflate:
413
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
414
+ else:
415
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
416
+ elif self.loss_func == 'multinomial':
417
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
418
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
419
+ elif self.loss_func == 'bernoulli':
420
+ if self.use_zeroinflate:
421
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
422
+ else:
423
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
424
+
425
+ def guide1(self, xs):
426
+ with pyro.plate('data'):
427
+ #zn_loc, zn_scale = self.encoder_zn(xs)
428
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
429
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
430
+
431
+ alpha = self.encoder_n(zns)
432
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
433
+
434
+ def model2(self, xs, us=None):
435
+ pyro.module('DensityFlow2', self)
436
+
437
+ eps = torch.finfo(xs.dtype).eps
438
+ batch_size = xs.size(0)
439
+ self.options = dict(dtype=xs.dtype, device=xs.device)
440
+
441
+ if self.loss_func=='negbinomial':
442
+ #total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
443
+ # xs.new_ones(self.input_size), constraint=constraints.positive)
444
+ with pyro.plate("genes", self.input_size):
445
+ inverse_dispersion = pyro.sample("inverse_dispersion", dist.LogNormal(self.inverse_dispersion, 0.5).to_event(1))
446
+
447
+ if self.use_zeroinflate:
448
+ gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
449
+
450
+ acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
451
+
452
+ I = torch.eye(self.code_size, **self.options)
453
+ if self.latent_dist=='studentt':
454
+ acs_dof,acs_loc = self.codebook(I)
455
+ else:
456
+ acs_loc = self.codebook(I)
457
+
458
+ with pyro.plate('data'):
459
+ prior = torch.zeros(batch_size, self.code_size, **self.options)
460
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))
461
+
462
+ zn_loc = torch.matmul(ns,acs_loc)
463
+ #zn_scale = torch.matmul(ns,acs_scale)
464
+ zn_scale = acs_scale
465
+
466
+ if self.latent_dist == 'studentt':
467
+ prior_dof = torch.matmul(ns,acs_dof)
468
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
469
+ elif self.latent_dist == 'laplacian':
470
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
471
+ elif self.latent_dist == 'cauchy':
472
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
473
+ elif self.latent_dist == 'normal':
474
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
475
+ elif self.latent_dist == 'gumbel':
476
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
477
+
478
+ '''if self.cell_factor_size>0:
479
+ zus = self._total_shifts(zns, us)
480
+ zs = zns+zus
481
+ else:
482
+ zs = zns'''
483
+
484
+ zs = zns
485
+ concentrate = self.decoder_concentrate(zs)
486
+ for i in np.arange(self.cell_factor_size):
487
+ zus = self._cell_shift(zs, i, us[:,i].reshape(-1,1))
488
+ concentrate += self.decoder_concentrate(zus)
489
+
490
+ if self.loss_func in ['bernoulli']:
491
+ log_theta = concentrate
492
+ elif self.loss_func in ['negbinomial']:
493
+ mu = concentrate.exp()
494
+ else:
495
+ rate = concentrate.exp()
496
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
497
+ if self.loss_func == 'poisson':
498
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
499
+
500
+ if self.loss_func == 'negbinomial':
501
+ logits = (mu.log()-inverse_dispersion.log()).clamp(min=-10, max=10)
502
+ if self.use_zeroinflate:
503
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=inverse_dispersion,
504
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
505
+ else:
506
+ pyro.sample('x', dist.NegativeBinomial(total_count=inverse_dispersion,
507
+ logits=logits).to_event(1), obs=xs)
508
+ elif self.loss_func == 'poisson':
509
+ if self.use_zeroinflate:
510
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
511
+ else:
512
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
513
+ elif self.loss_func == 'multinomial':
514
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
515
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
516
+ elif self.loss_func == 'bernoulli':
517
+ if self.use_zeroinflate:
518
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
519
+ else:
520
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
521
+
522
+ def guide2(self, xs, us=None):
523
+ with pyro.plate('data'):
524
+ #zn_loc, zn_scale = self.encoder_zn(xs)
525
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
526
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
527
+
528
+ alpha = self.encoder_n(zns)
529
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
530
+
531
+ if self.loss_func == 'negbinomial':
532
+ id_loc,id_scale = self.encoder_inverse_dispersion(zns)
533
+ with pyro.plate("genes", self.input_size):
534
+ pyro.sample("inverse_dispersion", dist.LogNormal(id_loc, id_scale).to_event(1))
535
+
536
+ def model3(self, xs, ys, embeds=None):
537
+ pyro.module('DensityFlow2', self)
538
+
539
+ eps = torch.finfo(xs.dtype).eps
540
+ batch_size = xs.size(0)
541
+ self.options = dict(dtype=xs.dtype, device=xs.device)
542
+
543
+ if self.loss_func=='negbinomial':
544
+ total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
545
+ xs.new_ones(self.input_size), constraint=constraints.positive)
546
+
547
+ if self.use_zeroinflate:
548
+ gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
549
+
550
+ acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
551
+
552
+ I = torch.eye(self.code_size, **self.options)
553
+ if self.latent_dist=='studentt':
554
+ acs_dof,acs_loc = self.codebook(I)
555
+ else:
556
+ acs_loc = self.codebook(I)
557
+
558
+ with pyro.plate('data'):
559
+ #prior = torch.zeros(batch_size, self.code_size, **self.options)
560
+ prior = self.encoder_n(xs)
561
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)
562
+
563
+ zn_loc = torch.matmul(ns,acs_loc)
564
+ #prior_scale = torch.matmul(ns,acs_scale)
565
+ zn_scale = acs_scale
566
+
567
+ if self.latent_dist=='studentt':
568
+ prior_dof = torch.matmul(ns,acs_dof)
569
+ if embeds is None:
570
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
571
+ else:
572
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
573
+ elif self.latent_dist=='laplacian':
574
+ if embeds is None:
575
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
576
+ else:
577
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
578
+ elif self.latent_dist=='cauchy':
579
+ if embeds is None:
580
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
581
+ else:
582
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
583
+ elif self.latent_dist=='normal':
584
+ if embeds is None:
585
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
586
+ else:
587
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
588
+ elif self.latent_dist == 'gumbel':
589
+ if embeds is None:
590
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
591
+ else:
592
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
593
+
594
+ zs = zns
595
+
596
+ concentrate = self.decoder_concentrate(zs)
597
+ if self.loss_func in ['bernoulli']:
598
+ log_theta = concentrate
599
+ else:
600
+ rate = concentrate.exp()
601
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
602
+ if self.loss_func == 'poisson':
603
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
604
+
605
+ if self.loss_func == 'negbinomial':
606
+ if self.use_zeroinflate:
607
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
608
+ else:
609
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
610
+ elif self.loss_func == 'poisson':
611
+ if self.use_zeroinflate:
612
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
613
+ else:
614
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
615
+ elif self.loss_func == 'multinomial':
616
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
617
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
618
+ elif self.loss_func == 'bernoulli':
619
+ if self.use_zeroinflate:
620
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
621
+ else:
622
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
623
+
624
+ def guide3(self, xs, ys, embeds=None):
625
+ with pyro.plate('data'):
626
+ if embeds is None:
627
+ #zn_loc, zn_scale = self.encoder_zn(xs)
628
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
629
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
630
+ else:
631
+ zns = embeds
632
+
633
+ def model4(self, xs, us, ys, embeds=None):
634
+ pyro.module('DensityFlow2', self)
635
+
636
+ eps = torch.finfo(xs.dtype).eps
637
+ batch_size = xs.size(0)
638
+ self.options = dict(dtype=xs.dtype, device=xs.device)
639
+
640
+ if self.loss_func=='negbinomial':
641
+ total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
642
+ xs.new_ones(self.input_size), constraint=constraints.positive)
643
+
644
+ if self.use_zeroinflate:
645
+ gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
646
+
647
+ acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
648
+
649
+ I = torch.eye(self.code_size, **self.options)
650
+ if self.latent_dist=='studentt':
651
+ acs_dof,acs_loc = self.codebook(I)
652
+ else:
653
+ acs_loc = self.codebook(I)
654
+
655
+ with pyro.plate('data'):
656
+ #prior = torch.zeros(batch_size, self.code_size, **self.options)
657
+ prior = self.encoder_n(xs)
658
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)
659
+
660
+ zn_loc = torch.matmul(ns,acs_loc)
661
+ #prior_scale = torch.matmul(ns,acs_scale)
662
+ zn_scale = acs_scale
663
+
664
+ if self.latent_dist=='studentt':
665
+ prior_dof = torch.matmul(ns,acs_dof)
666
+ if embeds is None:
667
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
668
+ else:
669
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
670
+ elif self.latent_dist=='laplacian':
671
+ if embeds is None:
672
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
673
+ else:
674
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
675
+ elif self.latent_dist=='cauchy':
676
+ if embeds is None:
677
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
678
+ else:
679
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
680
+ elif self.latent_dist=='normal':
681
+ if embeds is None:
682
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
683
+ else:
684
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
685
+ elif self.latent_dist == 'gumbel':
686
+ if embeds is None:
687
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
688
+ else:
689
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
690
+
691
+ '''if self.cell_factor_size>0:
692
+ zus = self._total_shifts(zns, us)
693
+ zs = zns+zus
694
+ else:
695
+ zs = zns'''
696
+
697
+ zs = zns
698
+ concentrate = self.decoder_concentrate(zs)
699
+ for i in np.arange(self.cell_factor_size):
700
+ zus = self._cell_shift(zs, i, us[:,i].reshape(-1,1))
701
+ concentrate += self.decoder_concentrate(zus)
702
+
703
+ if self.loss_func in ['bernoulli']:
704
+ log_theta = concentrate
705
+ else:
706
+ rate = concentrate.exp()
707
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
708
+ if self.loss_func == 'poisson':
709
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
710
+
711
+ if self.loss_func == 'negbinomial':
712
+ if self.use_zeroinflate:
713
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
714
+ else:
715
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
716
+ elif self.loss_func == 'poisson':
717
+ if self.use_zeroinflate:
718
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
719
+ else:
720
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
721
+ elif self.loss_func == 'multinomial':
722
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
723
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
724
+ elif self.loss_func == 'bernoulli':
725
+ if self.use_zeroinflate:
726
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
727
+ else:
728
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
729
+
730
+ def guide4(self, xs, us, ys, embeds=None):
731
+ with pyro.plate('data'):
732
+ if embeds is None:
733
+ #zn_loc, zn_scale = self.encoder_zn(xs)
734
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
735
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
736
+ else:
737
+ zns = embeds
738
+
739
+ def _total_shifts(self, zns, us):
740
+ zus = None
741
+ for i in np.arange(self.cell_factor_size):
742
+ if i==0:
743
+ #if self.turn_off_cell_specific:
744
+ # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
745
+ #else:
746
+ # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
747
+ zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
748
+ else:
749
+ #if self.turn_off_cell_specific:
750
+ # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
751
+ #else:
752
+ # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
753
+ zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
754
+ return zus
755
+
756
+ def _get_codebook_identity(self):
757
+ return torch.eye(self.code_size, **self.options)
758
+
759
+ def _get_codebook(self):
760
+ I = torch.eye(self.code_size, **self.options)
761
+ if self.latent_dist=='studentt':
762
+ _,cb = self.codebook(I)
763
+ else:
764
+ cb = self.codebook(I)
765
+ return cb
766
+
767
+ def get_codebook(self):
768
+ """
769
+ Return the mean part of the metacell codebook.
770
+ """
771
+ cb = self._get_codebook()
772
+ cb = tensor_to_numpy(cb)
773
+ return cb
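A small usage sketch (editorial note, not part of the diff), assuming a fitted model: the codebook locations come back as a (codebook_size, z_dim) NumPy array; note that _get_codebook builds the identity matrix with self.options, which is only populated once the model has processed data.

    cb = model.get_codebook()   # model: a fitted DensityFlow2 instance (hypothetical)
    print(cb.shape)             # expected: (codebook_size, z_dim), e.g. (100, 10)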
774
+
775
+ def _get_basal_embedding(self, xs):
776
+ loc, scale = self.encoder_zn(xs)
777
+ return loc, scale
778
+
779
+ def get_basal_embedding(self,
780
+ xs,
781
+ batch_size: int = 1024):
782
+ """
783
+ Return cells' basal latent representations
784
+
785
+ Parameters
786
+ ----------
787
+ xs
788
+ Single-cell expression matrix. It should be a NumPy array or a PyTorch tensor.
789
+ batch_size
790
+ Size of batch processing.
791
795
+ """
796
+ xs = self.preprocess(xs)
797
+ xs = convert_to_tensor(xs, device=self.get_device())
798
+ dataset = CustomDataset(xs)
799
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
800
+
801
+ Z = []
802
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
803
+ for X_batch, _ in dataloader:
804
+ zns,_ = self._get_basal_embedding(X_batch)
805
+ Z.append(tensor_to_numpy(zns))
806
+ pbar.update(1)
807
+
808
+ Z = np.concatenate(Z)
809
+ return Z
810
+
811
+ def _code(self, xs):
812
+ if self.supervised_mode:
813
+ alpha = self.encoder_n(xs)
814
+ else:
815
+ #zns,_ = self.encoder_zn(xs)
816
+ zns,_ = self._get_basal_embedding(xs)
817
+ alpha = self.encoder_n(zns)
818
+ return alpha
819
+
820
+ def code(self, xs, batch_size=1024):
821
+ xs = self.preprocess(xs)
822
+ xs = convert_to_tensor(xs, device=self.get_device())
823
+ dataset = CustomDataset(xs)
824
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
825
+
826
+ A = []
827
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
828
+ for X_batch, _ in dataloader:
829
+ a = self._code(X_batch)
830
+ A.append(tensor_to_numpy(a))
831
+ pbar.update(1)
832
+
833
+ A = np.concatenate(A)
834
+ return A
835
+
836
+ def _soft_assignments(self, xs):
837
+ alpha = self._code(xs)
838
+ alpha = self.softmax(alpha)
839
+ return alpha
840
+
841
+ def soft_assignments(self, xs, batch_size=1024):
842
+ """
843
+ Map cells to metacells and return the probabilistic values of metacell assignments
844
+ """
845
+ xs = self.preprocess(xs)
846
+ xs = convert_to_tensor(xs, device=self.get_device())
847
+ dataset = CustomDataset(xs)
848
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
849
+
850
+ A = []
851
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
852
+ for X_batch, _ in dataloader:
853
+ a = self._soft_assignments(X_batch)
854
+ A.append(tensor_to_numpy(a))
855
+ pbar.update(1)
856
+
857
+ A = np.concatenate(A)
858
+ return A
859
+
860
+ def _hard_assignments(self, xs):
861
+ alpha = self._code(xs)
862
+ res, ind = torch.topk(alpha, 1)
863
+ ns = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
864
+ return ns
865
+
866
+ def hard_assignments(self, xs, batch_size=1024):
867
+ """
868
+ Map cells to metacells and return the assigned metacell identities.
869
+ """
870
+ xs = self.preprocess(xs)
871
+ xs = convert_to_tensor(xs, device=self.get_device())
872
+ dataset = CustomDataset(xs)
873
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
874
+
875
+ A = []
876
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
877
+ for X_batch, _ in dataloader:
878
+ a = self._hard_assignments(X_batch)
879
+ A.append(tensor_to_numpy(a))
880
+ pbar.update(1)
881
+
882
+ A = np.concatenate(A)
883
+ return A
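A hedged usage sketch for the two assignment helpers (editorial note, not part of the diff); both take the raw expression matrix and return per-cell arrays over the metacell codebook. The names model and counts are hypothetical:

    probs = model.soft_assignments(counts)    # (n_cells, codebook_size) assignment probabilities
    onehot = model.hard_assignments(counts)   # (n_cells, codebook_size) one-hot assignments
    labels = onehot.argmax(axis=1)            # integer metacell index per cell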
884
+
885
+ def predict(self, xs, us, perturbs_predict:list, perturbs_reference:list, library_sizes=None):
886
+ perturbs_reference = np.array(perturbs_reference)
887
+
888
+ # basal embedding
889
+ zs = self.get_basal_embedding(xs)
890
+ for pert in perturbs_predict:
891
+ pert_idx = int(np.where(perturbs_reference == pert)[0][0])
892
+ us_i = us[:,pert_idx].reshape(-1,1)
893
+
894
+ # factor effect of xs
895
+ dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
896
+
897
+ # perturbation effect
898
+ ps = np.ones_like(us_i)
899
+ if np.sum(np.abs(ps-us_i))>=1:
900
+ dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
901
+ zs = zs + dzs0 + dzs
902
+ else:
903
+ zs = zs + dzs0
904
+
905
+ if library_sizes is None:
906
+ library_sizes = np.sum(xs, axis=1, keepdims=True)
907
+ elif type(library_sizes) == list:
908
+ library_sizes = np.array(library_sizes)
909
+ library_sizes = library_sizes.reshape(-1,1)
910
+ elif len(library_sizes.shape)==1:
911
+ library_sizes = library_sizes.reshape(-1,1)
912
+
913
+ counts = self.get_counts(self.get_theta(zs), library_sizes=library_sizes)
914
+
915
+ return counts, zs
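A hedged sketch of calling predict (editorial note, not part of the diff), assuming a model trained with cell_factor_size > 0, a factor matrix us aligned row-wise with counts, and factor names in perturbs_reference; 'drugA' and 'drugB' are hypothetical column names:

    counts_hat, zs_hat = model.predict(
        counts,                                   # cells x genes expression matrix
        us,                                       # cells x n_factors factor matrix
        perturbs_predict=['drugA'],               # factor(s) to perturb to 1
        perturbs_reference=['drugA', 'drugB'],    # column order of us
    )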
916
+
917
+ def _cell_shift(self, zs, perturb_idx, perturb):
918
+ #zns,_ = self.encoder_zn(xs)
919
+ #zns,_ = self._get_basal_embedding(xs)
920
+ zns = zs
921
+ if perturb.ndim==2:
922
+ if self.turn_off_cell_specific:
923
+ ms = self.cell_factor_effect[perturb_idx](perturb)
924
+ else:
925
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
926
+ else:
927
+ if self.turn_off_cell_specific:
928
+ ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
929
+ else:
930
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
931
+
932
+ return ms
933
+
934
+ def get_cell_shift(self,
935
+ zs,
936
+ perturb_idx,
937
+ perturb_us,
938
+ batch_size: int = 1024):
939
+ """
940
+ Return cells' changes in the latent space induced by a specific perturbation of a factor.
941
+
942
+ """
943
+ #xs = self.preprocess(xs)
944
+ zs = convert_to_tensor(zs, device=self.get_device())
945
+ ps = convert_to_tensor(perturb_us, device=self.get_device())
946
+ dataset = CustomDataset2(zs,ps)
947
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
948
+
949
+ Z = []
950
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
951
+ for Z_batch, P_batch, _ in dataloader:
952
+ zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
953
+ Z.append(tensor_to_numpy(zns))
954
+ pbar.update(1)
955
+
956
+ Z = np.concatenate(Z)
957
+ return Z
958
+
959
+ def _get_theta(self, delta_zs):
960
+ return self.decoder_concentrate(delta_zs)
961
+
962
+ def get_theta(self,
963
+ delta_zs,
964
+ batch_size: int = 1024):
965
+ """
966
+ Return cells' changes in the feature space induced by a specific perturbation of a factor.
967
+
968
+ """
969
+ delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
970
+ dataset = CustomDataset(delta_zs)
971
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
972
+
973
+ R = []
974
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
975
+ for delta_Z_batch, _ in dataloader:
976
+ r = self._get_theta(delta_Z_batch)
977
+ R.append(tensor_to_numpy(r))
978
+ pbar.update(1)
979
+
980
+ R = np.concatenate(R)
981
+ return R
982
+
983
+ def _count(self, concentrate, library_size=None):
984
+ if self.loss_func == 'bernoulli':
985
+ #counts = self.sigmoid(concentrate)
986
+ counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
987
+ elif self.loss_func == 'multinomial':
988
+ theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
989
+ counts = theta * library_size
990
+ else:
991
+ rate = concentrate.exp()
992
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
993
+ counts = theta * library_size
994
+ return counts
995
+
996
+ def get_counts(self, concentrate,
997
+ library_sizes,
998
+ batch_size: int = 1024):
999
+
1000
+ concentrate = convert_to_tensor(concentrate, device=self.get_device())
1001
+
1002
+ if type(library_sizes) == list:
1003
+ library_sizes = np.array(library_sizes).reshape(-1,1)
1004
+ elif len(library_sizes.shape)==1:
1005
+ library_sizes = library_sizes.reshape(-1,1)
1006
+ ls = convert_to_tensor(library_sizes, device=self.get_device())
1007
+
1008
+ dataset = CustomDataset2(concentrate,ls)
1009
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
1010
+
1011
+ E = []
1012
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
1013
+ for C_batch, L_batch, _ in dataloader:
1014
+ counts = self._count(C_batch, L_batch)
1015
+ E.append(tensor_to_numpy(counts))
1016
+ pbar.update(1)
1017
+
1018
+ E = np.concatenate(E)
1019
+ return E
1020
+
1021
+ def preprocess(self, xs, threshold=0):
1022
+ if self.loss_func == 'bernoulli':
1023
+ ad = sc.AnnData(xs)
1024
+ binarize(ad, threshold=threshold)
1025
+ xs = ad.X.copy()
1026
+ else:
1027
+ xs = np.round(xs)
1028
+
1029
+ if sparse.issparse(xs):
1030
+ xs = xs.toarray()
1031
+ return xs
1032
+
1033
+ def fit(self, xs,
1034
+ us = None,
1035
+ ys = None,
1036
+ zs = None,
1037
+ num_epochs: int = 500,
1038
+ learning_rate: float = 0.0001,
1039
+ batch_size: int = 256,
1040
+ algo: Literal['adam','rmsprop','adamw'] = 'adam',
1041
+ beta_1: float = 0.9,
1042
+ weight_decay: float = 0.005,
1043
+ decay_rate: float = 0.9,
1044
+ config_enum: str = 'parallel',
1045
+ threshold: int = 0,
1046
+ use_jax: bool = True):
1047
+ """
1048
+ Train the DensityFlow2 model.
1049
+
1050
+ Parameters
1051
+ ----------
1052
+ xs
1053
+ Single-cell expression matrix. It should be a NumPy array or a PyTorch tensor. Rows are cells and columns are features.
1054
+ us
1055
+ cell-level factor matrix.
1056
+ ys
1057
+ Desired factor matrix. It should be a NumPy array or a PyTorch tensor. Rows are cells and columns are desired factors.
1058
+ num_epochs
1059
+ Number of training epochs.
1060
+ learning_rate
1061
+ Parameter for training.
1062
+ batch_size
1063
+ Size of batch processing.
1064
+ algo
1065
+ Optimization algorithm.
1066
+ beta_1
1067
+ Parameter for optimization.
1068
+ weight_decay
1069
+ Parameter for optimization.
1070
+ decay_rate
1071
+ Parameter for optimization.
1072
+ use_jax
1073
+ If toggled on, JAX will be used to speed up training. CAUTION: for unknown reasons this may raise errors when called from a
1074
+ Python script or a Jupyter notebook, but it works when DensityFlow2 is run from the shell command line.
1075
+ """
1076
+ xs = self.preprocess(xs, threshold=threshold)
1077
+ xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
1078
+ if us is not None:
1079
+ us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
1080
+ if ys is not None:
1081
+ ys = convert_to_tensor(ys, dtype=self.dtype, device=self.get_device())
1082
+ if zs is not None:
1083
+ zs = convert_to_tensor(zs, dtype=self.dtype, device=self.get_device())
1084
+
1085
+ dataset = CustomDataset4(xs, us, ys, zs)
1086
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
1087
+
1088
+ # setup the optimizer
1089
+ optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}
1090
+
1091
+ if algo.lower()=='rmsprop':
1092
+ optimizer = torch.optim.RMSprop
1093
+ elif algo.lower()=='adam':
1094
+ optimizer = torch.optim.Adam
1095
+ elif algo.lower() == 'adamw':
1096
+ optimizer = torch.optim.AdamW
1097
+ else:
1098
+ raise ValueError("An optimization algorithm must be specified.")
1099
+ scheduler = ExponentialLR({'optimizer': optimizer, 'optim_args': optim_params, 'gamma': decay_rate})
1100
+
1101
+ pyro.clear_param_store()
1102
+
1103
+ # set up the loss(es) for inference, wrapping the guide in config_enumerate builds the loss as a sum
1104
+ # by enumerating each class label form the sampled discrete categorical distribution in the model
1105
+ Elbo = JitTraceEnum_ELBO if use_jax else TraceEnum_ELBO
1106
+ elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
1107
+ if us is None:
1108
+ if ys is None:
1109
+ guide = config_enumerate(self.guide1, config_enum, expand=True)
1110
+ loss_basic = SVI(self.model1, guide, scheduler, loss=elbo)
1111
+ else:
1112
+ guide = config_enumerate(self.guide3, config_enum, expand=True)
1113
+ loss_basic = SVI(self.model3, guide, scheduler, loss=elbo)
1114
+ else:
1115
+ if ys is None:
1116
+ guide = config_enumerate(self.guide2, config_enum, expand=True)
1117
+ loss_basic = SVI(self.model2, guide, scheduler, loss=elbo)
1118
+ else:
1119
+ guide = config_enumerate(self.guide4, config_enum, expand=True)
1120
+ loss_basic = SVI(self.model4, guide, scheduler, loss=elbo)
1121
+
1122
+ # build a list of all losses considered
1123
+ losses = [loss_basic]
1124
+ num_losses = len(losses)
1125
+
1126
+ with tqdm(total=num_epochs, desc='Training', unit='epoch') as pbar:
1127
+ for epoch in range(num_epochs):
1128
+ epoch_losses = [0.0] * num_losses
1129
+ for batch_x, batch_u, batch_y, batch_z, _ in dataloader:
1130
+ if us is None:
1131
+ batch_u = None
1132
+ if ys is None:
1133
+ batch_y = None
1134
+ if zs is None:
1135
+ batch_z = None
1136
+
1137
+ for loss_id in range(num_losses):
1138
+ if batch_u is None:
1139
+ if batch_y is None:
1140
+ new_loss = losses[loss_id].step(batch_x)
1141
+ else:
1142
+ new_loss = losses[loss_id].step(batch_x, batch_y, batch_z)
1143
+ else:
1144
+ if batch_y is None:
1145
+ new_loss = losses[loss_id].step(batch_x, batch_u)
1146
+ else:
1147
+ new_loss = losses[loss_id].step(batch_x, batch_u, batch_y, batch_z)
1148
+ epoch_losses[loss_id] += new_loss
1149
+
1150
+ avg_epoch_losses_ = map(lambda v: v / len(dataloader), epoch_losses)
1151
+ avg_epoch_losses = map(lambda v: "{:.4f}".format(v), avg_epoch_losses_)
1152
+
1153
+ # store the loss
1154
+ str_loss = " ".join(map(str, avg_epoch_losses))
1155
+
1156
+ # Update progress bar
1157
+ pbar.set_postfix({'loss': str_loss})
1158
+ pbar.update(1)
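A minimal training sketch (editorial note, not part of the diff); which model/guide pair is used follows the branching above on whether us and ys are supplied. The arrays counts, us and onehot_labels are hypothetical:

    model.fit(counts, num_epochs=100, batch_size=256)     # unsupervised: model1/guide1
    # model.fit(counts, us=us)                            # with cell-level factors: model2/guide2
    # model.fit(counts, ys=onehot_labels)                 # with metacell labels: model3/guide3
    # model.fit(counts, us=us, ys=onehot_labels)          # both: model4/guide4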
1159
+
1160
+ @classmethod
1161
+ def save_model(cls, model, file_path, compression=False):
1162
+ """Save the model to the specified file path."""
1163
+ file_path = os.path.abspath(file_path)
1164
+
1165
+ model.eval()
1166
+ if compression:
1167
+ with gzip.open(file_path, 'wb') as pickle_file:
1168
+ pickle.dump(model, pickle_file)
1169
+ else:
1170
+ with open(file_path, 'wb') as pickle_file:
1171
+ pickle.dump(model, pickle_file)
1172
+
1173
+ print(f'Model saved to {file_path}')
1174
+
1175
+ @classmethod
1176
+ def load_model(cls, file_path):
1177
+ """Load the model from the specified file path and return an instance."""
1178
+ print(f'Model loaded from {file_path}')
1179
+
1180
+ file_path = os.path.abspath(file_path)
1181
+ if file_path.endswith('gz'):
1182
+ with gzip.open(file_path, 'rb') as pickle_file:
1183
+ model = pickle.load(pickle_file)
1184
+ else:
1185
+ with open(file_path, 'rb') as pickle_file:
1186
+ model = pickle.load(pickle_file)
1187
+
1188
+ return model
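A save/load round-trip sketch (editorial note, not part of the diff): save_model needs compression=True to write gzip, while load_model infers gzip from a 'gz' suffix. The path is a hypothetical placeholder:

    DensityFlow2.save_model(model, 'densityflow2.pkl.gz', compression=True)
    model_reloaded = DensityFlow2.load_model('densityflow2.pkl.gz')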
1189
+
1190
+
1191
+ EXAMPLE_RUN = (
1192
+ "example run: DensityFlow2 --help"
1193
+ )
1194
+
1195
+ def parse_args():
1196
+ parser = argparse.ArgumentParser(
1197
+ description="DensityFlow2\n{}".format(EXAMPLE_RUN))
1198
+
1199
+ parser.add_argument(
1200
+ "--cuda", action="store_true", help="use GPU(s) to speed up training"
1201
+ )
1202
+ parser.add_argument(
1203
+ "--jit", action="store_true", help="use PyTorch jit to speed up training"
1204
+ )
1205
+ parser.add_argument(
1206
+ "-n", "--num-epochs", default=200, type=int, help="number of epochs to run"
1207
+ )
1208
+ parser.add_argument(
1209
+ "-enum",
1210
+ "--enum-discrete",
1211
+ default="parallel",
1212
+ help="parallel, sequential or none. uses parallel enumeration by default",
1213
+ )
1214
+ parser.add_argument(
1215
+ "-data",
1216
+ "--data-file",
1217
+ default=None,
1218
+ type=str,
1219
+ help="the data file",
1220
+ )
1221
+ parser.add_argument(
1222
+ "-cf",
1223
+ "--cell-factor-file",
1224
+ default=None,
1225
+ type=str,
1226
+ help="the file for the record of cell-level factors",
1227
+ )
1228
+ parser.add_argument(
1229
+ "-bs",
1230
+ "--batch-size",
1231
+ default=1000,
1232
+ type=int,
1233
+ help="number of cells to be considered in a batch",
1234
+ )
1235
+ parser.add_argument(
1236
+ "-lr",
1237
+ "--learning-rate",
1238
+ default=0.0001,
1239
+ type=float,
1240
+ help="learning rate for Adam optimizer",
1241
+ )
1242
+ parser.add_argument(
1243
+ "-cs",
1244
+ "--codebook-size",
1245
+ default=100,
1246
+ type=int,
1247
+ help="size of vector quantization codebook",
1248
+ )
1249
+ parser.add_argument(
1250
+ "--z-dist",
1251
+ default='gumbel',
1252
+ type=str,
1253
+ choices=['normal','laplacian','studentt','gumbel','cauchy'],
1254
+ help="distribution model for latent representation",
1255
+ )
1256
+ parser.add_argument(
1257
+ "-zd",
1258
+ "--z-dim",
1259
+ default=10,
1260
+ type=int,
1261
+ help="size of the tensor representing the latent variable z variable",
1262
+ )
1263
+ parser.add_argument(
1264
+ "-likeli",
1265
+ "--likelihood",
1266
+ default='negbinomial',
1267
+ type=str,
1268
+ choices=['negbinomial', 'multinomial', 'poisson', 'bernoulli'],
1269
+ help="specify the distribution likelihood function",
1270
+ )
1271
+ parser.add_argument(
1272
+ "-zi",
1273
+ "--zeroinflate",
1274
+ action="store_true",
1275
+ help="use zero-inflated estimation",
1276
+ )
1277
+ parser.add_argument(
1278
+ "-id",
1279
+ "--inverse-dispersion",
1280
+ default=10.0,
1281
+ type=float,
1282
+ help="inverse dispersion prior for negative binomial",
1283
+ )
1284
+ parser.add_argument(
1285
+ "-hl",
1286
+ "--hidden-layers",
1287
+ nargs="+",
1288
+ default=[500],
1289
+ type=int,
1290
+ help="a tuple (or list) of MLP layers to be used in the neural networks "
1291
+ "representing the parameters of the distributions in our model",
1292
+ )
1293
+ parser.add_argument(
1294
+ "-hla",
1295
+ "--hidden-layer-activation",
1296
+ default='relu',
1297
+ type=str,
1298
+ choices=['relu','softplus','leakyrelu','linear'],
1299
+ help="activation function for hidden layers",
1300
+ )
1301
+ parser.add_argument(
1302
+ "-plf",
1303
+ "--post-layer-function",
1304
+ nargs="+",
1305
+ default=['layernorm'],
1306
+ type=str,
1307
+ help="post functions for hidden layers, could be none, dropout, layernorm, batchnorm, or combination, default is 'dropout layernorm'",
1308
+ )
1309
+ parser.add_argument(
1310
+ "-paf",
1311
+ "--post-activation-function",
1312
+ nargs="+",
1313
+ default=['none'],
1314
+ type=str,
1315
+ help="post functions for activation layers, could be none or dropout, default is 'none'",
1316
+ )
1317
+ parser.add_argument(
1318
+ "-64",
1319
+ "--float64",
1320
+ action="store_true",
1321
+ help="use double float precision",
1322
+ )
1323
+ parser.add_argument(
1324
+ "-dr",
1325
+ "--decay-rate",
1326
+ default=0.9,
1327
+ type=float,
1328
+ help="decay rate for Adam optimizer",
1329
+ )
1330
+ parser.add_argument(
1331
+ "--layer-dropout-rate",
1332
+ default=0.1,
1333
+ type=float,
1334
+ help="droput rate for neural networks",
1335
+ )
1336
+ parser.add_argument(
1337
+ "-b1",
1338
+ "--beta-1",
1339
+ default=0.95,
1340
+ type=float,
1341
+ help="beta-1 parameter for Adam optimizer",
1342
+ )
1343
+ parser.add_argument(
1344
+ "--seed",
1345
+ default=None,
1346
+ type=int,
1347
+ help="seed for controlling randomness in this example",
1348
+ )
1349
+ parser.add_argument(
1350
+ "--save-model",
1351
+ default=None,
1352
+ type=str,
1353
+ help="path to save model for prediction",
1354
+ )
1355
+ args = parser.parse_args()
1356
+ return args
1357
+
1358
+ def main():
1359
+ args = parse_args()
1360
+ assert (
1361
+ (args.data_file is not None) and (
1362
+ os.path.exists(args.data_file))
1363
+ ), "data file must be provided"
1364
+
1365
+ if args.seed is not None:
1366
+ set_random_seed(args.seed)
1367
+
1368
+ if args.float64:
1369
+ dtype = torch.float64
1370
+ torch.set_default_dtype(torch.float64)
1371
+ else:
1372
+ dtype = torch.float32
1373
+ torch.set_default_dtype(torch.float32)
1374
+
1375
+ xs = dt.fread(file=args.data_file, header=True).to_numpy()
1376
+ us = None
1377
+ if args.cell_factor_file is not None:
1378
+ us = dt.fread(file=args.cell_factor_file, header=True).to_numpy()
1379
+
1380
+ input_size = xs.shape[1]
1381
+ cell_factor_size = 0 if us is None else us.shape[1]
1382
+
1383
+ ###########################################
1384
+ df = DensityFlow2(
1385
+ input_size=input_size,
1386
+ cell_factor_size=cell_factor_size,
1387
+ inverse_dispersion=args.inverse_dispersion,
1388
+ z_dim=args.z_dim,
1389
+ hidden_layers=args.hidden_layers,
1390
+ hidden_layer_activation=args.hidden_layer_activation,
1391
+ use_cuda=args.cuda,
1392
+ config_enum=args.enum_discrete,
1393
+ use_zeroinflate=args.zeroinflate,
1394
+ loss_func=args.likelihood,
1395
+ nn_dropout=args.layer_dropout_rate,
1396
+ post_layer_fct=args.post_layer_function,
1397
+ post_act_fct=args.post_activation_function,
1398
+ codebook_size=args.codebook_size,
1399
+ z_dist = args.z_dist,
1400
+ dtype=dtype,
1401
+ )
1402
+
1403
+ df.fit(xs, us=us,
1404
+ num_epochs=args.num_epochs,
1405
+ learning_rate=args.learning_rate,
1406
+ batch_size=args.batch_size,
1407
+ beta_1=args.beta_1,
1408
+ decay_rate=args.decay_rate,
1409
+ use_jax=args.jit,
1410
+ config_enum=args.enum_discrete,
1411
+ )
1412
+
1413
+ if args.save_model is not None:
1414
+ if args.save_model.endswith('gz'):
1415
+ DensityFlow2.save_model(df, args.save_model, compression=True)
1416
+ else:
1417
+ DensityFlow2.save_model(df, args.save_model)
1418
+
1419
+
1420
+
1421
+ if __name__ == "__main__":
1422
+ main()
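Editorial note (not part of the diff): with the argument parser above, a typical invocation would be python -m SURE.DensityFlow2 --data-file counts.csv --cell-factor-file factors.csv --num-epochs 200 --cuda --save-model densityflow2.pkl.gz, where the CSV paths are hypothetical placeholders; a --save-model path ending in 'gz' is written with gzip compression.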