SURE-tools 2.1.92__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of SURE-tools has been flagged as possibly problematic.
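The main change in this release is the new SURE/DensityFlow.py module added below. For orientation, here is a minimal, untested sketch of how the added DensityFlow class could be constructed and trained, based only on the signatures visible in this diff; the import path, the toy data, and the argument choices are assumptions, not part of the release:

    import numpy as np
    from SURE.DensityFlow import DensityFlow  # assumed import path

    # Toy count matrix (cells x features) and a single binary cell-level factor.
    xs = np.random.poisson(1.0, size=(500, 2000))
    us = np.random.binomial(1, 0.5, size=(500, 1)).astype(float)

    df = DensityFlow(
        input_size=xs.shape[1],
        cell_factor_size=us.shape[1],
        codebook_size=100,
        z_dim=10,
        loss_func='poisson',
        use_cuda=False,
    )
    df.fit(xs, us=us, num_epochs=10, batch_size=128, use_jax=False)
    DensityFlow.save_model(df, 'densityflow_model.pkl')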

SURE/DensityFlow.py ADDED
@@ -0,0 +1,1359 @@
1
+ import pyro
2
+ import pyro.distributions as dist
3
+ from pyro.optim import ExponentialLR
4
+ from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import DataLoader
9
+ from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_probs
10
+ from torch.distributions import constraints
11
+ from torch.distributions.transforms import SoftmaxTransform
12
+
13
+ from .utils.custom_mlp import MLP, Exp, ZeroBiasMLP
14
+ from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor
15
+
16
+
17
+ import os
18
+ import argparse
19
+ import random
20
+ import numpy as np
21
+ import datatable as dt
22
+ from tqdm import tqdm
23
+ from scipy import sparse
24
+
25
+ import scanpy as sc
26
+ from .atac import binarize
27
+
28
+ from typing import Literal
29
+
30
+ import warnings
31
+ warnings.filterwarnings("ignore")
32
+
33
+ import dill as pickle
34
+ import gzip
35
+ from packaging.version import Version
36
+ torch_version = torch.__version__
37
+
38
+
39
+ def set_random_seed(seed):
40
+ # Set seed for PyTorch
41
+ torch.manual_seed(seed)
42
+
43
+ # If using CUDA, set the seed for CUDA
44
+ if torch.cuda.is_available():
45
+ torch.cuda.manual_seed(seed)
46
+ torch.cuda.manual_seed_all(seed) # For multi-GPU setups.
47
+
48
+ # Set seed for NumPy
49
+ np.random.seed(seed)
50
+
51
+ # Set seed for Python's random module
52
+ random.seed(seed)
53
+
54
+ # Set seed for Pyro
55
+ pyro.set_rng_seed(seed)
56
+
57
+ class DensityFlow(nn.Module):
58
+ def __init__(self,
59
+ input_size: int,
60
+ codebook_size: int = 200,
61
+ cell_factor_size: int = 0,
62
+ supervised_mode: bool = False,
63
+ z_dim: int = 10,
64
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
66
+ inverse_dispersion: float = 10.0,
67
+ use_zeroinflate: bool = False,
68
+ hidden_layers: list = [300],
69
+ hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
70
+ nn_dropout: float = 0.1,
71
+ post_layer_fct: list = ['layernorm'],
72
+ post_act_fct: list = None,
73
+ config_enum: str = 'parallel',
74
+ use_cuda: bool = True,
75
+ seed: int = 42,
76
+ zero_bias: bool|list = True,
77
+ dtype = torch.float32, # type: ignore
78
+ ):
79
+ super().__init__()
80
+
81
+ self.input_size = input_size
82
+ self.cell_factor_size = cell_factor_size
83
+ self.inverse_dispersion = inverse_dispersion
84
+ self.latent_dim = z_dim
85
+ self.hidden_layers = hidden_layers
86
+ self.decoder_hidden_layers = hidden_layers[::-1]
87
+ self.allow_broadcast = config_enum == 'parallel'
88
+ self.use_cuda = use_cuda
89
+ self.loss_func = loss_func
90
+ self.options = None
91
+ self.code_size=codebook_size
92
+ self.supervised_mode=supervised_mode
93
+ self.latent_dist = z_dist
94
+ self.dtype = dtype
95
+ self.use_zeroinflate=use_zeroinflate
96
+ self.nn_dropout = nn_dropout
97
+ self.post_layer_fct = post_layer_fct
98
+ self.post_act_fct = post_act_fct
99
+ self.hidden_layer_activation = hidden_layer_activation
100
+ if type(zero_bias) == list:
101
+ self.use_bias = [not x for x in zero_bias]
102
+ else:
103
+ self.use_bias = [not zero_bias] * self.cell_factor_size
104
+ #self.use_bias = not zero_bias
105
+
106
+ self.codebook_weights = None
107
+
108
+ set_random_seed(seed)
109
+ self.setup_networks()
110
+
111
+ def setup_networks(self):
112
+ latent_dim = self.latent_dim
113
+ hidden_sizes = self.hidden_layers
114
+
115
+ nn_layer_norm, nn_batch_norm, nn_layer_dropout = False, False, False
116
+ na_layer_norm, na_batch_norm, na_layer_dropout = False, False, False
117
+
118
+ if self.post_layer_fct is not None:
119
+ nn_layer_norm=True if ('layernorm' in self.post_layer_fct) or ('layer_norm' in self.post_layer_fct) else False
120
+ nn_batch_norm=True if ('batchnorm' in self.post_layer_fct) or ('batch_norm' in self.post_layer_fct) else False
121
+ nn_layer_dropout=True if 'dropout' in self.post_layer_fct else False
122
+
123
+ if self.post_act_fct is not None:
124
+ na_layer_norm=True if ('layernorm' in self.post_act_fct) or ('layer_norm' in self.post_act_fct) else False
125
+ na_batch_norm=True if ('batchnorm' in self.post_act_fct) or ('batch_norm' in self.post_act_fct) else False
126
+ na_layer_dropout=True if 'dropout' in self.post_act_fct else False
127
+
128
+ if nn_layer_norm and nn_batch_norm and nn_layer_dropout:
129
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout),nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
130
+ elif nn_layer_norm and nn_layer_dropout:
131
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.LayerNorm(layer.module.out_features))
132
+ elif nn_batch_norm and nn_layer_dropout:
133
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features))
134
+ elif nn_layer_norm and nn_batch_norm:
135
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
136
+ elif nn_layer_norm:
137
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.LayerNorm(layer.module.out_features)
138
+ elif nn_batch_norm:
139
+ post_layer_fct = lambda layer_ix, total_layers, layer:nn.BatchNorm1d(layer.module.out_features)
140
+ elif nn_layer_dropout:
141
+ post_layer_fct = lambda layer_ix, total_layers, layer: nn.Dropout(self.nn_dropout)
142
+ else:
143
+ post_layer_fct = lambda layer_ix, total_layers, layer: None
144
+
145
+ if na_layer_norm and na_batch_norm and na_layer_dropout:
146
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout),nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
147
+ elif na_layer_norm and na_layer_dropout:
148
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.LayerNorm(layer.module.out_features))
149
+ elif na_batch_norm and na_layer_dropout:
150
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features))
151
+ elif na_layer_norm and na_batch_norm:
152
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
153
+ elif na_layer_norm:
154
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.LayerNorm(layer.module.out_features)
155
+ elif na_batch_norm:
156
+ post_act_fct = lambda layer_ix, total_layers, layer:nn.BatchNorm1d(layer.module.out_features)
157
+ elif na_layer_dropout:
158
+ post_act_fct = lambda layer_ix, total_layers, layer: nn.Dropout(self.nn_dropout)
159
+ else:
160
+ post_act_fct = lambda layer_ix, total_layers, layer: None
161
+
162
+ if self.hidden_layer_activation == 'relu':
163
+ activate_fct = nn.ReLU
164
+ elif self.hidden_layer_activation == 'softplus':
165
+ activate_fct = nn.Softplus
166
+ elif self.hidden_layer_activation == 'leakyrelu':
167
+ activate_fct = nn.LeakyReLU
168
+ elif self.hidden_layer_activation == 'linear':
169
+ activate_fct = nn.Identity
170
+
171
+ if self.supervised_mode:
172
+ self.encoder_n = MLP(
173
+ [self.input_size] + hidden_sizes + [self.code_size],
174
+ activation=activate_fct,
175
+ output_activation=None,
176
+ post_layer_fct=post_layer_fct,
177
+ post_act_fct=post_act_fct,
178
+ allow_broadcast=self.allow_broadcast,
179
+ use_cuda=self.use_cuda,
180
+ )
181
+ else:
182
+ self.encoder_n = MLP(
183
+ [self.latent_dim] + hidden_sizes + [self.code_size],
184
+ activation=activate_fct,
185
+ output_activation=None,
186
+ post_layer_fct=post_layer_fct,
187
+ post_act_fct=post_act_fct,
188
+ allow_broadcast=self.allow_broadcast,
189
+ use_cuda=self.use_cuda,
190
+ )
191
+
192
+ self.encoder_zn = MLP(
193
+ [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
194
+ activation=activate_fct,
195
+ output_activation=[None, Exp],
196
+ post_layer_fct=post_layer_fct,
197
+ post_act_fct=post_act_fct,
198
+ allow_broadcast=self.allow_broadcast,
199
+ use_cuda=self.use_cuda,
200
+ )
201
+
202
+ if self.cell_factor_size>0:
203
+ self.cell_factor_effect = nn.ModuleList()
204
+ for i in np.arange(self.cell_factor_size):
205
+ if self.use_bias[i]:
206
+ self.cell_factor_effect.append(MLP(
207
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
208
+ activation=activate_fct,
209
+ output_activation=None,
210
+ post_layer_fct=post_layer_fct,
211
+ post_act_fct=post_act_fct,
212
+ allow_broadcast=self.allow_broadcast,
213
+ use_cuda=self.use_cuda,
214
+ )
215
+ )
216
+ else:
217
+ self.cell_factor_effect.append(ZeroBiasMLP(
218
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
219
+ activation=activate_fct,
220
+ output_activation=None,
221
+ post_layer_fct=post_layer_fct,
222
+ post_act_fct=post_act_fct,
223
+ allow_broadcast=self.allow_broadcast,
224
+ use_cuda=self.use_cuda,
225
+ )
226
+ )
227
+
228
+ self.decoder_concentrate = MLP(
229
+ [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
230
+ activation=activate_fct,
231
+ output_activation=None,
232
+ post_layer_fct=post_layer_fct,
233
+ post_act_fct=post_act_fct,
234
+ allow_broadcast=self.allow_broadcast,
235
+ use_cuda=self.use_cuda,
236
+ )
237
+
238
+ if self.latent_dist == 'studentt':
239
+ self.codebook = MLP(
240
+ [self.code_size] + hidden_sizes + [[latent_dim,latent_dim]],
241
+ activation=activate_fct,
242
+ output_activation=[Exp,None],
243
+ post_layer_fct=post_layer_fct,
244
+ post_act_fct=post_act_fct,
245
+ allow_broadcast=self.allow_broadcast,
246
+ use_cuda=self.use_cuda,
247
+ )
248
+ else:
249
+ self.codebook = MLP(
250
+ [self.code_size] + hidden_sizes + [latent_dim],
251
+ activation=activate_fct,
252
+ output_activation=None,
253
+ post_layer_fct=post_layer_fct,
254
+ post_act_fct=post_act_fct,
255
+ allow_broadcast=self.allow_broadcast,
256
+ use_cuda=self.use_cuda,
257
+ )
258
+
259
+ if self.use_cuda:
260
+ self.cuda()
261
+
262
+ def get_device(self):
263
+ return next(self.parameters()).device
264
+
265
+ def cutoff(self, xs, thresh=None):
266
+ eps = torch.finfo(xs.dtype).eps
267
+
268
+ if not thresh is None:
269
+ if eps < thresh:
270
+ eps = thresh
271
+
272
+ xs = xs.clamp(min=eps)
273
+
274
+ if torch.any(torch.isnan(xs)):
275
+ xs[torch.isnan(xs)] = eps
276
+
277
+ return xs
278
+
279
+ def softmax(self, xs):
280
+ #xs = SoftmaxTransform()(xs)
281
+ xs = dist.Multinomial(total_count=1, logits=xs).mean
282
+ return xs
283
+
284
+ def sigmoid(self, xs):
285
+ #sigm_enc = nn.Sigmoid()
286
+ #xs = sigm_enc(xs)
287
+ #xs = clamp_probs(xs)
288
+ xs = dist.Bernoulli(logits=xs).mean
289
+ return xs
290
+
291
+ def softmax_logit(self, xs):
292
+ eps = torch.finfo(xs.dtype).eps
293
+ xs = self.softmax(xs)
294
+ xs = torch.logit(xs, eps=eps)
295
+ return xs
296
+
297
+ def logit(self, xs):
298
+ eps = torch.finfo(xs.dtype).eps
299
+ xs = torch.logit(xs, eps=eps)
300
+ return xs
301
+
302
+ def dirimulti_param(self, xs):
303
+ xs = self.dirimulti_mass * self.sigmoid(xs)
304
+ return xs
305
+
306
+ def multi_param(self, xs):
307
+ xs = self.softmax(xs)
308
+ return xs
309
+
310
+ def model1(self, xs):
311
+ pyro.module('DensityFlow', self)
312
+
313
+ eps = torch.finfo(xs.dtype).eps
314
+ batch_size = xs.size(0)
315
+ self.options = dict(dtype=xs.dtype, device=xs.device)
316
+
317
+ if self.loss_func=='negbinomial':
318
+ total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
319
+ xs.new_ones(self.input_size), constraint=constraints.positive)
320
+
321
+ if self.use_zeroinflate:
322
+ gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
323
+
324
+ acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
325
+
326
+ I = torch.eye(self.code_size, **self.options)
327
+ if self.latent_dist=='studentt':
328
+ acs_dof,acs_loc = self.codebook(I)
329
+ else:
330
+ acs_loc = self.codebook(I)
331
+
332
+ with pyro.plate('data'):
333
+ prior = torch.zeros(batch_size, self.code_size, **self.options)
334
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))
335
+
336
+ zn_loc = torch.matmul(ns,acs_loc)
337
+ #zn_scale = torch.matmul(ns,acs_scale)
338
+ zn_scale = acs_scale
339
+
340
+ if self.latent_dist == 'studentt':
341
+ prior_dof = torch.matmul(ns,acs_dof)
342
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
343
+ elif self.latent_dist == 'laplacian':
344
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
345
+ elif self.latent_dist == 'cauchy':
346
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
347
+ elif self.latent_dist == 'normal':
348
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
349
+ elif self.latent_dist == 'gumbel':
350
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
351
+
352
+ zs = zns
353
+ concentrate = self.decoder_concentrate(zs)
354
+ if self.loss_func in ['bernoulli']:
355
+ log_theta = concentrate
356
+ else:
357
+ rate = concentrate.exp()
358
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
359
+ if self.loss_func == 'poisson':
360
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
361
+
362
+ if self.loss_func == 'negbinomial':
363
+ if self.use_zeroinflate:
364
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
365
+ else:
366
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
367
+ elif self.loss_func == 'poisson':
368
+ if self.use_zeroinflate:
369
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
370
+ else:
371
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
372
+ elif self.loss_func == 'multinomial':
373
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
374
+ elif self.loss_func == 'bernoulli':
375
+ if self.use_zeroinflate:
376
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
377
+ else:
378
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
379
+
380
+ def guide1(self, xs):
381
+ with pyro.plate('data'):
382
+ #zn_loc, zn_scale = self.encoder_zn(xs)
383
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
384
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
385
+
386
+ alpha = self.encoder_n(zns)
387
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
388
+
389
+ def model2(self, xs, us=None):
390
+ pyro.module('DensityFlow', self)
391
+
392
+ eps = torch.finfo(xs.dtype).eps
393
+ batch_size = xs.size(0)
394
+ self.options = dict(dtype=xs.dtype, device=xs.device)
395
+
396
+ if self.loss_func=='negbinomial':
397
+ total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
398
+ xs.new_ones(self.input_size), constraint=constraints.positive)
399
+
400
+ if self.use_zeroinflate:
401
+ gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
402
+
403
+ acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
404
+
405
+ I = torch.eye(self.code_size, **self.options)
406
+ if self.latent_dist=='studentt':
407
+ acs_dof,acs_loc = self.codebook(I)
408
+ else:
409
+ acs_loc = self.codebook(I)
410
+
411
+ with pyro.plate('data'):
412
+ prior = torch.zeros(batch_size, self.code_size, **self.options)
413
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))
414
+
415
+ zn_loc = torch.matmul(ns,acs_loc)
416
+ #zn_scale = torch.matmul(ns,acs_scale)
417
+ zn_scale = acs_scale
418
+
419
+ if self.latent_dist == 'studentt':
420
+ prior_dof = torch.matmul(ns,acs_dof)
421
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
422
+ elif self.latent_dist == 'laplacian':
423
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
424
+ elif self.latent_dist == 'cauchy':
425
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
426
+ elif self.latent_dist == 'normal':
427
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
428
+ elif self.latent_dist == 'gumbel':
429
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
430
+
431
+ if self.cell_factor_size>0:
432
+ zus = self._total_effects(zns, us)
433
+ zs = zns+zus
434
+ else:
435
+ zs = zns
436
+
437
+ concentrate = self.decoder_concentrate(zs)
438
+ if self.loss_func in ['bernoulli']:
439
+ log_theta = concentrate
440
+ else:
441
+ rate = concentrate.exp()
442
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
443
+ if self.loss_func == 'poisson':
444
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
445
+
446
+ if self.loss_func == 'negbinomial':
447
+ if self.use_zeroinflate:
448
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
449
+ else:
450
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
451
+ elif self.loss_func == 'poisson':
452
+ if self.use_zeroinflate:
453
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
454
+ else:
455
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
456
+ elif self.loss_func == 'multinomial':
457
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
458
+ elif self.loss_func == 'bernoulli':
459
+ if self.use_zeroinflate:
460
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
461
+ else:
462
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
463
+
464
+ def guide2(self, xs, us=None):
465
+ with pyro.plate('data'):
466
+ #zn_loc, zn_scale = self.encoder_zn(xs)
467
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
468
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
469
+
470
+ alpha = self.encoder_n(zns)
471
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
472
+
473
+ def model3(self, xs, ys, embeds=None):
474
+ pyro.module('DensityFlow', self)
475
+
476
+ eps = torch.finfo(xs.dtype).eps
477
+ batch_size = xs.size(0)
478
+ self.options = dict(dtype=xs.dtype, device=xs.device)
479
+
480
+ if self.loss_func=='negbinomial':
481
+ total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
482
+ xs.new_ones(self.input_size), constraint=constraints.positive)
483
+
484
+ if self.use_zeroinflate:
485
+ gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
486
+
487
+ acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
488
+
489
+ I = torch.eye(self.code_size, **self.options)
490
+ if self.latent_dist=='studentt':
491
+ acs_dof,acs_loc = self.codebook(I)
492
+ else:
493
+ acs_loc = self.codebook(I)
494
+
495
+ with pyro.plate('data'):
496
+ #prior = torch.zeros(batch_size, self.code_size, **self.options)
497
+ prior = self.encoder_n(xs)
498
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)
499
+
500
+ zn_loc = torch.matmul(ns,acs_loc)
501
+ #prior_scale = torch.matmul(ns,acs_scale)
502
+ zn_scale = acs_scale
503
+
504
+ if self.latent_dist=='studentt':
505
+ prior_dof = torch.matmul(ns,acs_dof)
506
+ if embeds is None:
507
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
508
+ else:
509
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
510
+ elif self.latent_dist=='laplacian':
511
+ if embeds is None:
512
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
513
+ else:
514
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
515
+ elif self.latent_dist=='cauchy':
516
+ if embeds is None:
517
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
518
+ else:
519
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
520
+ elif self.latent_dist=='normal':
521
+ if embeds is None:
522
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
523
+ else:
524
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
525
+ elif self.latent_dist == 'gumbel':
526
+ if embeds is None:
527
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
528
+ else:
529
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
530
+
531
+ zs = zns
532
+
533
+ concentrate = self.decoder_concentrate(zs)
534
+ if self.loss_func in ['bernoulli']:
535
+ log_theta = concentrate
536
+ else:
537
+ rate = concentrate.exp()
538
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
539
+ if self.loss_func == 'poisson':
540
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
541
+
542
+ if self.loss_func == 'negbinomial':
543
+ if self.use_zeroinflate:
544
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
545
+ else:
546
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
547
+ elif self.loss_func == 'poisson':
548
+ if self.use_zeroinflate:
549
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
550
+ else:
551
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
552
+ elif self.loss_func == 'multinomial':
553
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
554
+ elif self.loss_func == 'bernoulli':
555
+ if self.use_zeroinflate:
556
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
557
+ else:
558
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
559
+
560
+ def guide3(self, xs, ys, embeds=None):
561
+ with pyro.plate('data'):
562
+ if embeds is None:
563
+ #zn_loc, zn_scale = self.encoder_zn(xs)
564
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
565
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
566
+ else:
567
+ zns = embeds
568
+
569
+ def model4(self, xs, us, ys, embeds=None):
570
+ pyro.module('DensityFlow', self)
571
+
572
+ eps = torch.finfo(xs.dtype).eps
573
+ batch_size = xs.size(0)
574
+ self.options = dict(dtype=xs.dtype, device=xs.device)
575
+
576
+ if self.loss_func=='negbinomial':
577
+ total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
578
+ xs.new_ones(self.input_size), constraint=constraints.positive)
579
+
580
+ if self.use_zeroinflate:
581
+ gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
582
+
583
+ acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
584
+
585
+ I = torch.eye(self.code_size, **self.options)
586
+ if self.latent_dist=='studentt':
587
+ acs_dof,acs_loc = self.codebook(I)
588
+ else:
589
+ acs_loc = self.codebook(I)
590
+
591
+ with pyro.plate('data'):
592
+ #prior = torch.zeros(batch_size, self.code_size, **self.options)
593
+ prior = self.encoder_n(xs)
594
+ ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)
595
+
596
+ zn_loc = torch.matmul(ns,acs_loc)
597
+ #prior_scale = torch.matmul(ns,acs_scale)
598
+ zn_scale = acs_scale
599
+
600
+ if self.latent_dist=='studentt':
601
+ prior_dof = torch.matmul(ns,acs_dof)
602
+ if embeds is None:
603
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
604
+ else:
605
+ zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
606
+ elif self.latent_dist=='laplacian':
607
+ if embeds is None:
608
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
609
+ else:
610
+ zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
611
+ elif self.latent_dist=='cauchy':
612
+ if embeds is None:
613
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
614
+ else:
615
+ zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
616
+ elif self.latent_dist=='normal':
617
+ if embeds is None:
618
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
619
+ else:
620
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
621
+ elif self.latent_dist == 'gumbel':
622
+ if embeds is None:
623
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
624
+ else:
625
+ zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
626
+
627
+ if self.cell_factor_size>0:
628
+ #zus = None
629
+ #for i in np.arange(self.cell_factor_size):
630
+ # if i==0:
631
+ # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
632
+ # else:
633
+ # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
634
+ zus = self._total_effects(zns, us)
635
+ zs = zns+zus
636
+ else:
637
+ zs = zns
638
+
639
+ concentrate = self.decoder_concentrate(zs)
640
+ if self.loss_func in ['bernoulli']:
641
+ log_theta = concentrate
642
+ else:
643
+ rate = concentrate.exp()
644
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
645
+ if self.loss_func == 'poisson':
646
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
647
+
648
+ if self.loss_func == 'negbinomial':
649
+ if self.use_zeroinflate:
650
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
651
+ else:
652
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
653
+ elif self.loss_func == 'poisson':
654
+ if self.use_zeroinflate:
655
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
656
+ else:
657
+ pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
658
+ elif self.loss_func == 'multinomial':
659
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
660
+ elif self.loss_func == 'bernoulli':
661
+ if self.use_zeroinflate:
662
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
663
+ else:
664
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
665
+
666
+ def guide4(self, xs, us, ys, embeds=None):
667
+ with pyro.plate('data'):
668
+ if embeds is None:
669
+ #zn_loc, zn_scale = self.encoder_zn(xs)
670
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
671
+ zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
672
+ else:
673
+ zns = embeds
674
+
675
+ def _total_effects(self, zns, us):
676
+ zus = None
677
+ for i in np.arange(self.cell_factor_size):
678
+ if i==0:
679
+ zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
680
+ else:
681
+ zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
682
+ return zus
683
+
684
+ def _get_codebook_identity(self):
685
+ return torch.eye(self.code_size, **self.options)
686
+
687
+ def _get_codebook(self):
688
+ I = torch.eye(self.code_size, **self.options)
689
+ if self.latent_dist=='studentt':
690
+ _,cb = self.codebook(I)
691
+ else:
692
+ cb = self.codebook(I)
693
+ return cb
694
+
695
+ def get_codebook(self):
696
+ """
697
+ Return the mean part of the metacell codebook.
698
+ """
699
+ cb = self._get_codebook()
700
+ cb = tensor_to_numpy(cb)
701
+ return cb
702
+
703
+ def _get_basal_embedding(self, xs):
704
+ loc, scale = self.encoder_zn(xs)
705
+ return loc, scale
706
+
707
+ def get_basal_embedding(self,
708
+ xs,
709
+ batch_size: int = 1024):
710
+ """
711
+ Return cells' basal latent representations
712
+
713
+ Parameters
714
+ ----------
715
+ xs
716
+ Single-cell expression matrix. It should be a Numpy array or a Pytorch Tensor.
717
+ batch_size
718
+ Size of batch processing.
719
+ Returns
720
+ -------
721
+ Z
722
+ A Numpy array of shape (n_cells, z_dim) containing the basal latent coordinates.
723
+ """
724
+ xs = self.preprocess(xs)
725
+ xs = convert_to_tensor(xs, device=self.get_device())
726
+ dataset = CustomDataset(xs)
727
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
728
+
729
+ Z = []
730
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
731
+ for X_batch, _ in dataloader:
732
+ zns,_ = self._get_basal_embedding(X_batch)
733
+ Z.append(tensor_to_numpy(zns))
734
+ pbar.update(1)
735
+
736
+ Z = np.concatenate(Z)
737
+ return Z
738
+
739
+ def _code(self, xs):
740
+ if self.supervised_mode:
741
+ alpha = self.encoder_n(xs)
742
+ else:
743
+ #zns,_ = self.encoder_zn(xs)
744
+ zns,_ = self._get_basal_embedding(xs)
745
+ alpha = self.encoder_n(zns)
746
+ return alpha
747
+
748
+ def code(self, xs, batch_size=1024):
749
+ xs = self.preprocess(xs)
750
+ xs = convert_to_tensor(xs, device=self.get_device())
751
+ dataset = CustomDataset(xs)
752
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
753
+
754
+ A = []
755
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
756
+ for X_batch, _ in dataloader:
757
+ a = self._code(X_batch)
758
+ A.append(tensor_to_numpy(a))
759
+ pbar.update(1)
760
+
761
+ A = np.concatenate(A)
762
+ return A
763
+
764
+ def _soft_assignments(self, xs):
765
+ alpha = self._code(xs)
766
+ alpha = self.softmax(alpha)
767
+ return alpha
768
+
769
+ def soft_assignments(self, xs, batch_size=1024):
770
+ """
771
+ Map cells to metacells and return the probabilistic values of metacell assignments
772
+ """
773
+ xs = self.preprocess(xs)
774
+ xs = convert_to_tensor(xs, device=self.get_device())
775
+ dataset = CustomDataset(xs)
776
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
777
+
778
+ A = []
779
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
780
+ for X_batch, _ in dataloader:
781
+ a = self._soft_assignments(X_batch)
782
+ A.append(tensor_to_numpy(a))
783
+ pbar.update(1)
784
+
785
+ A = np.concatenate(A)
786
+ return A
787
+
788
+ def _hard_assignments(self, xs):
789
+ alpha = self._code(xs)
790
+ res, ind = torch.topk(alpha, 1)
791
+ ns = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
792
+ return ns
793
+
794
+ def hard_assignments(self, xs, batch_size=1024):
795
+ """
796
+ Map cells to metacells and return the assigned metacell identities.
797
+ """
798
+ xs = self.preprocess(xs)
799
+ xs = convert_to_tensor(xs, device=self.get_device())
800
+ dataset = CustomDataset(xs)
801
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
802
+
803
+ A = []
804
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
805
+ for X_batch, _ in dataloader:
806
+ a = self._hard_assignments(X_batch)
807
+ A.append(tensor_to_numpy(a))
808
+ pbar.update(1)
809
+
810
+ A = np.concatenate(A)
811
+ return A
812
+
813
+ def predict(self, xs, us, perturbs_predict:list, perturbs_reference:list, library_sizes=None):
814
+ perturbs_reference = np.array(perturbs_reference)
815
+
816
+ # basal embedding
817
+ zs = self.get_basal_embedding(xs)
818
+ for pert in perturbs_predict:
819
+ pert_idx = int(np.where(perturbs_reference==pert)[0])
820
+ us_i = us[:,pert_idx].reshape(-1,1)
821
+
822
+ # factor effect of xs
823
+ dzs0 = self.get_cell_response(xs, factor_idx=pert_idx, perturb=us_i)
824
+
825
+ # perturbation effect
826
+ ps = np.ones_like(us_i)
827
+ dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
828
+
829
+ zs = zs + dzs0 + dzs
830
+
831
+ if library_sizes is None:
832
+ library_sizes = np.sum(xs, axis=1, keepdims=True)
833
+ elif type(library_sizes) == list:
834
+ library_sizes = np.array(library_sizes)
835
+ library_sizes = library_sizes.reshape(-1,1)
836
+ elif len(library_sizes.shape)==1:
837
+ library_sizes = library_sizes.reshape(-1,1)
838
+
839
+ counts = self.get_counts(zs, library_sizes=library_sizes)
840
+
841
+ return counts, zs
842
+
843
+ def _cell_response(self, xs, factor_idx, perturb):
844
+ #zns,_ = self.encoder_zn(xs)
845
+ zns,_ = self._get_basal_embedding(xs)
846
+ if perturb.ndim==2:
847
+ ms = self.cell_factor_effect[factor_idx]([zns, perturb])
848
+ else:
849
+ ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
850
+
851
+ return ms
852
+
853
+ def get_cell_response(self,
854
+ xs,
855
+ factor_idx,
856
+ perturb,
857
+ batch_size: int = 1024):
858
+ """
859
+ Return cells' changes in the latent space induced by specific perturbation of a factor
860
+
861
+ """
862
+ xs = self.preprocess(xs)
863
+ xs = convert_to_tensor(xs, device=self.get_device())
864
+ ps = convert_to_tensor(perturb, device=self.get_device())
865
+ dataset = CustomDataset2(xs,ps)
866
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
867
+
868
+ Z = []
869
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
870
+ for X_batch, P_batch, _ in dataloader:
871
+ zns = self._cell_response(X_batch, factor_idx, P_batch)
872
+ Z.append(tensor_to_numpy(zns))
873
+ pbar.update(1)
874
+
875
+ Z = np.concatenate(Z)
876
+ return Z
877
+
878
+ def get_metacell_response(self, factor_idx, perturb):
879
+ zs = self._get_codebook()
880
+ ps = convert_to_tensor(perturb, device=self.get_device())
881
+ ms = self.cell_factor_effect[factor_idx]([zs,ps])
882
+ return tensor_to_numpy(ms)
883
+
884
+ def _get_expression_response(self, delta_zs):
885
+ return self.decoder_concentrate(delta_zs)
886
+
887
+ def get_expression_response(self,
888
+ delta_zs,
889
+ batch_size: int = 1024):
890
+ """
891
+ Return cells' changes in the feature space induced by specific perturbation of a factor
892
+
893
+ """
894
+ delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
895
+ dataset = CustomDataset(delta_zs)
896
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
897
+
898
+ R = []
899
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
900
+ for delta_Z_batch, _ in dataloader:
901
+ r = self._get_expression_response(delta_Z_batch)
902
+ R.append(tensor_to_numpy(r))
903
+ pbar.update(1)
904
+
905
+ R = np.concatenate(R)
906
+ return R
907
+
908
+ def _count(self,concentrate, library_size=None):
909
+ if self.loss_func == 'bernoulli':
910
+ #counts = self.sigmoid(concentrate)
911
+ counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
912
+ else:
913
+ rate = concentrate.exp()
914
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
915
+ counts = theta * library_size if library_size is not None else theta
916
+ #counts = dist.Poisson(rate=rate).to_event(1).mean
917
+ return counts
918
+
919
+ def _count_sample(self,concentrate):
920
+ if self.loss_func == 'bernoulli':
921
+ logits = concentrate
922
+ counts = dist.Bernoulli(logits=logits).to_event(1).sample()
923
+ else:
924
+ counts = self._count(concentrate=concentrate)
925
+ counts = dist.Poisson(rate=counts).to_event(1).sample()
926
+ return counts
927
+
928
+ def get_counts(self, zs, library_sizes,
929
+ batch_size: int = 1024,
930
+ use_sampler: bool = False):
931
+
932
+ zs = convert_to_tensor(zs, device=self.get_device())
933
+
934
+ if type(library_sizes) == list:
935
+ library_sizes = np.array(library_sizes).reshape(-1,1)
936
+ elif len(library_sizes.shape)==1:
937
+ library_sizes = library_sizes.reshape(-1,1)
938
+ ls = convert_to_tensor(library_sizes, device=self.get_device())
939
+
940
+ dataset = CustomDataset2(zs,ls)
941
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
942
+
943
+ E = []
944
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
945
+ for Z_batch, L_batch, _ in dataloader:
946
+ concentrate = self._get_expression_response(Z_batch)
947
+ if use_sampler:
948
+ counts = self._count_sample(concentrate)
949
+ else:
950
+ counts = self._count(concentrate, L_batch)
951
+ E.append(tensor_to_numpy(counts))
952
+ pbar.update(1)
953
+
954
+ E = np.concatenate(E)
955
+ return E
956
+
957
+ def preprocess(self, xs, threshold=0):
958
+ if self.loss_func == 'bernoulli':
959
+ ad = sc.AnnData(xs)
960
+ binarize(ad, threshold=threshold)
961
+ xs = ad.X.copy()
962
+ else:
963
+ xs = np.round(xs)
964
+
965
+ if sparse.issparse(xs):
966
+ xs = xs.toarray()
967
+ return xs
968
+
969
+ def fit(self, xs,
970
+ us = None,
971
+ ys = None,
972
+ zs = None,
973
+ num_epochs: int = 200,
974
+ learning_rate: float = 0.0001,
975
+ batch_size: int = 256,
976
+ algo: Literal['adam','rmsprop','adamw'] = 'adam',
977
+ beta_1: float = 0.9,
978
+ weight_decay: float = 0.005,
979
+ decay_rate: float = 0.9,
980
+ config_enum: str = 'parallel',
981
+ threshold: int = 0,
982
+ use_jax: bool = True):
983
+ """
984
+ Train the DensityFlow model.
985
+
986
+ Parameters
987
+ ----------
988
+ xs
989
+ Single-cell expression matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are features.
990
+ us
991
+ cell-level factor matrix.
992
+ ys
993
+ Desired factor matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are desired factors.
994
+ num_epochs
995
+ Number of training epochs.
996
+ learning_rate
997
+ Parameter for training.
998
+ batch_size
999
+ Size of batch processing.
1000
+ algo
1001
+ Optimization algorithm.
1002
+ beta_1
1003
+ Parameter for optimization.
1004
+ weight_decay
1005
+ Parameter for optimization.
1006
+ decay_rate
1007
+ Parameter for optimization.
1008
+ use_jax
1009
+ If toggled on, a JIT-compiled ELBO is used to speed up training. CAUTION: for unknown reasons this may raise errors when called from
1010
+ a Python script or Jupyter notebook. It works when DensityFlow is run from the shell command line.
1011
+ """
1012
+ xs = self.preprocess(xs, threshold=threshold)
1013
+ xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
1014
+ if us is not None:
1015
+ us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
1016
+ if ys is not None:
1017
+ ys = convert_to_tensor(ys, dtype=self.dtype, device=self.get_device())
1018
+ if zs is not None:
1019
+ zs = convert_to_tensor(zs, dtype=self.dtype, device=self.get_device())
1020
+
1021
+ dataset = CustomDataset4(xs, us, ys, zs)
1022
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
1023
+
1024
+ # setup the optimizer
1025
+ optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}
1026
+
1027
+ if algo.lower()=='rmsprop':
1028
+ optimizer = torch.optim.RMSprop
1029
+ elif algo.lower()=='adam':
1030
+ optimizer = torch.optim.Adam
1031
+ elif algo.lower() == 'adamw':
1032
+ optimizer = torch.optim.AdamW
1033
+ else:
1034
+ raise ValueError("An optimization algorithm must be specified.")
1035
+ scheduler = ExponentialLR({'optimizer': optimizer, 'optim_args': optim_params, 'gamma': decay_rate})
1036
+
1037
+ pyro.clear_param_store()
1038
+
1039
+ # set up the loss(es) for inference, wrapping the guide in config_enumerate builds the loss as a sum
1040
+ # by enumerating each class label from the sampled discrete categorical distribution in the model
1041
+ Elbo = JitTraceEnum_ELBO if use_jax else TraceEnum_ELBO
1042
+ elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
1043
+ if us is None:
1044
+ if ys is None:
1045
+ guide = config_enumerate(self.guide1, config_enum, expand=True)
1046
+ loss_basic = SVI(self.model1, guide, scheduler, loss=elbo)
1047
+ else:
1048
+ guide = config_enumerate(self.guide3, config_enum, expand=True)
1049
+ loss_basic = SVI(self.model3, guide, scheduler, loss=elbo)
1050
+ else:
1051
+ if ys is None:
1052
+ guide = config_enumerate(self.guide2, config_enum, expand=True)
1053
+ loss_basic = SVI(self.model2, guide, scheduler, loss=elbo)
1054
+ else:
1055
+ guide = config_enumerate(self.guide4, config_enum, expand=True)
1056
+ loss_basic = SVI(self.model4, guide, scheduler, loss=elbo)
1057
+
1058
+ # build a list of all losses considered
1059
+ losses = [loss_basic]
1060
+ num_losses = len(losses)
1061
+
1062
+ with tqdm(total=num_epochs, desc='Training', unit='epoch') as pbar:
1063
+ for epoch in range(num_epochs):
1064
+ epoch_losses = [0.0] * num_losses
1065
+ for batch_x, batch_u, batch_y, batch_z, _ in dataloader:
1066
+ if us is None:
1067
+ batch_u = None
1068
+ if ys is None:
1069
+ batch_y = None
1070
+ if zs is None:
1071
+ batch_z = None
1072
+
1073
+ for loss_id in range(num_losses):
1074
+ if batch_u is None:
1075
+ if batch_y is None:
1076
+ new_loss = losses[loss_id].step(batch_x)
1077
+ else:
1078
+ new_loss = losses[loss_id].step(batch_x, batch_y, batch_z)
1079
+ else:
1080
+ if batch_y is None:
1081
+ new_loss = losses[loss_id].step(batch_x, batch_u)
1082
+ else:
1083
+ new_loss = losses[loss_id].step(batch_x, batch_u, batch_y, batch_z)
1084
+ epoch_losses[loss_id] += new_loss
1085
+
1086
+ avg_epoch_losses_ = map(lambda v: v / len(dataloader), epoch_losses)
1087
+ avg_epoch_losses = map(lambda v: "{:.4f}".format(v), avg_epoch_losses_)
1088
+
1089
+ # store the loss
1090
+ str_loss = " ".join(map(str, avg_epoch_losses))
1091
+
1092
+ # Update progress bar
1093
+ pbar.set_postfix({'loss': str_loss})
1094
+ pbar.update(1)
1095
+
1096
+ @classmethod
1097
+ def save_model(cls, model, file_path, compression=False):
1098
+ """Save the model to the specified file path."""
1099
+ file_path = os.path.abspath(file_path)
1100
+
1101
+ model.eval()
1102
+ if compression:
1103
+ with gzip.open(file_path, 'wb') as pickle_file:
1104
+ pickle.dump(model, pickle_file)
1105
+ else:
1106
+ with open(file_path, 'wb') as pickle_file:
1107
+ pickle.dump(model, pickle_file)
1108
+
1109
+ print(f'Model saved to {file_path}')
1110
+
1111
+ @classmethod
1112
+ def load_model(cls, file_path):
1113
+ """Load the model from the specified file path and return an instance."""
1114
+ print(f'Model loaded from {file_path}')
1115
+
1116
+ file_path = os.path.abspath(file_path)
1117
+ if file_path.endswith('gz'):
1118
+ with gzip.open(file_path, 'rb') as pickle_file:
1119
+ model = pickle.load(pickle_file)
1120
+ else:
1121
+ with open(file_path, 'rb') as pickle_file:
1122
+ model = pickle.load(pickle_file)
1123
+
1124
+ return model
1125
+
1126
+
1127
+ EXAMPLE_RUN = (
1128
+ "example run: DensityFlow --help"
1129
+ )
1130
+
1131
+ def parse_args():
1132
+ parser = argparse.ArgumentParser(
1133
+ description="DensityFlow\n{}".format(EXAMPLE_RUN))
1134
+
1135
+ parser.add_argument(
1136
+ "--cuda", action="store_true", help="use GPU(s) to speed up training"
1137
+ )
1138
+ parser.add_argument(
1139
+ "--jit", action="store_true", help="use PyTorch jit to speed up training"
1140
+ )
1141
+ parser.add_argument(
1142
+ "-n", "--num-epochs", default=200, type=int, help="number of epochs to run"
1143
+ )
1144
+ parser.add_argument(
1145
+ "-enum",
1146
+ "--enum-discrete",
1147
+ default="parallel",
1148
+ help="parallel, sequential or none. uses parallel enumeration by default",
1149
+ )
1150
+ parser.add_argument(
1151
+ "-data",
1152
+ "--data-file",
1153
+ default=None,
1154
+ type=str,
1155
+ help="the data file",
1156
+ )
1157
+ parser.add_argument(
1158
+ "-cf",
1159
+ "--cell-factor-file",
1160
+ default=None,
1161
+ type=str,
1162
+ help="the file for the record of cell-level factors",
1163
+ )
1164
+ parser.add_argument(
1165
+ "-bs",
1166
+ "--batch-size",
1167
+ default=1000,
1168
+ type=int,
1169
+ help="number of cells to be considered in a batch",
1170
+ )
1171
+ parser.add_argument(
1172
+ "-lr",
1173
+ "--learning-rate",
1174
+ default=0.0001,
1175
+ type=float,
1176
+ help="learning rate for Adam optimizer",
1177
+ )
1178
+ parser.add_argument(
1179
+ "-cs",
1180
+ "--codebook-size",
1181
+ default=100,
1182
+ type=int,
1183
+ help="size of vector quantization codebook",
1184
+ )
1185
+ parser.add_argument(
1186
+ "--z-dist",
1187
+ default='gumbel',
1188
+ type=str,
1189
+ choices=['normal','laplacian','studentt','gumbel','cauchy'],
1190
+ help="distribution model for latent representation",
1191
+ )
1192
+ parser.add_argument(
1193
+ "-zd",
1194
+ "--z-dim",
1195
+ default=10,
1196
+ type=int,
1197
+ help="size of the tensor representing the latent variable z",
1198
+ )
1199
+ parser.add_argument(
1200
+ "-likeli",
1201
+ "--likelihood",
1202
+ default='negbinomial',
1203
+ type=str,
1204
+ choices=['negbinomial', 'multinomial', 'poisson', 'bernoulli'],
1205
+ help="specify the distribution likelihood function",
1206
+ )
1207
+ parser.add_argument(
1208
+ "-zi",
1209
+ "--zeroinflate",
1210
+ action="store_true",
1211
+ help="use zero-inflated estimation",
1212
+ )
1213
+ parser.add_argument(
1214
+ "-id",
1215
+ "--inverse-dispersion",
1216
+ default=10.0,
1217
+ type=float,
1218
+ help="inverse dispersion prior for negative binomial",
1219
+ )
1220
+ parser.add_argument(
1221
+ "-hl",
1222
+ "--hidden-layers",
1223
+ nargs="+",
1224
+ default=[500],
1225
+ type=int,
1226
+ help="a tuple (or list) of MLP layers to be used in the neural networks "
1227
+ "representing the parameters of the distributions in our model",
1228
+ )
1229
+ parser.add_argument(
1230
+ "-hla",
1231
+ "--hidden-layer-activation",
1232
+ default='relu',
1233
+ type=str,
1234
+ choices=['relu','softplus','leakyrelu','linear'],
1235
+ help="activation function for hidden layers",
1236
+ )
1237
+ parser.add_argument(
1238
+ "-plf",
1239
+ "--post-layer-function",
1240
+ nargs="+",
1241
+ default=['layernorm'],
1242
+ type=str,
1243
+ help="post functions for hidden layers; can be none, dropout, layernorm, batchnorm, or a combination; default is 'layernorm'",
1244
+ )
1245
+ parser.add_argument(
1246
+ "-paf",
1247
+ "--post-activation-function",
1248
+ nargs="+",
1249
+ default=['none'],
1250
+ type=str,
1251
+ help="post functions for activation layers, could be none or dropout, default is 'none'",
1252
+ )
1253
+ parser.add_argument(
1254
+ "-64",
1255
+ "--float64",
1256
+ action="store_true",
1257
+ help="use double float precision",
1258
+ )
1259
+ parser.add_argument(
1260
+ "-dr",
1261
+ "--decay-rate",
1262
+ default=0.9,
1263
+ type=float,
1264
+ help="decay rate for Adam optimizer",
1265
+ )
1266
+ parser.add_argument(
1267
+ "--layer-dropout-rate",
1268
+ default=0.1,
1269
+ type=float,
1270
+ help="dropout rate for neural networks",
1271
+ )
1272
+ parser.add_argument(
1273
+ "-b1",
1274
+ "--beta-1",
1275
+ default=0.95,
1276
+ type=float,
1277
+ help="beta-1 parameter for Adam optimizer",
1278
+ )
1279
+ parser.add_argument(
1280
+ "--seed",
1281
+ default=None,
1282
+ type=int,
1283
+ help="seed for controlling randomness in this example",
1284
+ )
1285
+ parser.add_argument(
1286
+ "--save-model",
1287
+ default=None,
1288
+ type=str,
1289
+ help="path to save model for prediction",
1290
+ )
1291
+ args = parser.parse_args()
1292
+ return args
1293
+
1294
+ def main():
1295
+ args = parse_args()
1296
+ assert (
1297
+ (args.data_file is not None) and (
1298
+ os.path.exists(args.data_file))
1299
+ ), "data file must be provided"
1300
+
1301
+ if args.seed is not None:
1302
+ set_random_seed(args.seed)
1303
+
1304
+ if args.float64:
1305
+ dtype = torch.float64
1306
+ torch.set_default_dtype(torch.float64)
1307
+ else:
1308
+ dtype = torch.float32
1309
+ torch.set_default_dtype(torch.float32)
1310
+
1311
+ xs = dt.fread(file=args.data_file, header=True).to_numpy()
1312
+ us = None
1313
+ if args.cell_factor_file is not None:
1314
+ us = dt.fread(file=args.cell_factor_file, header=True).to_numpy()
1315
+
1316
+ input_size = xs.shape[1]
1317
+ cell_factor_size = 0 if us is None else us.shape[1]
1318
+
1319
+ ###########################################
1320
+ df = DensityFlow(
1321
+ input_size=input_size,
1322
+ cell_factor_size=cell_factor_size,
1323
+ inverse_dispersion=args.inverse_dispersion,
1324
+ z_dim=args.z_dim,
1325
+ hidden_layers=args.hidden_layers,
1326
+ hidden_layer_activation=args.hidden_layer_activation,
1327
+ use_cuda=args.cuda,
1328
+ config_enum=args.enum_discrete,
1329
+ use_zeroinflate=args.zeroinflate,
1330
+ loss_func=args.likelihood,
1331
+ nn_dropout=args.layer_dropout_rate,
1332
+ post_layer_fct=args.post_layer_function,
1333
+ post_act_fct=args.post_activation_function,
1334
+ codebook_size=args.codebook_size,
1335
+ z_dist = args.z_dist,
1336
+ dtype=dtype,
1337
+ )
1338
+
1339
+ df.fit(xs, us=us,
1340
+ num_epochs=args.num_epochs,
1341
+ learning_rate=args.learning_rate,
1342
+ batch_size=args.batch_size,
1343
+ beta_1=args.beta_1,
1344
+ decay_rate=args.decay_rate,
1345
+ use_jax=args.jit,
1346
+ config_enum=args.enum_discrete,
1347
+ )
1348
+
1349
+ if args.save_model is not None:
1350
+ if args.save_model.endswith('gz'):
1351
+ DensityFlow.save_model(df, args.save_model, compression=True)
1352
+ else:
1353
+ DensityFlow.save_model(df, args.save_model)
1354
+
1355
+
1356
+
1357
+ if __name__ == "__main__":
1358
+
1359
+ main()
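
Beyond training, the new module exposes embedding, assignment, and perturbation-prediction helpers. A hedged sketch of that downstream workflow, reusing the hypothetical trained model df and data xs/us from the sketch near the top of this page (the factor name 'drug_A' is purely illustrative):

    # Basal latent coordinates and probabilistic metacell assignments.
    latent = df.get_basal_embedding(xs)       # (n_cells, z_dim) array
    assignments = df.soft_assignments(xs)     # (n_cells, codebook_size) probabilities

    # Predicted counts and latent positions under a perturbation of one cell-level factor.
    counts, zs = df.predict(
        xs, us,
        perturbs_predict=['drug_A'],          # factors to perturb (illustrative name)
        perturbs_reference=['drug_A'],        # column order of us
    )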