SURE-tools 1.0.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of SURE-tools might be problematic.

SURE/SURE.py ADDED
@@ -0,0 +1,1203 @@
import pyro
import pyro.distributions as dist
from pyro.optim import ExponentialLR
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_probs
from torch.distributions import constraints
from torch.distributions.transforms import SoftmaxTransform

from .utils.custom_mlp import MLP, Exp
from .utils.utils import CustomDataset, CustomDataset3, CustomDataset4, tensor_to_numpy, convert_to_tensor


import os
import argparse
import random
import numpy as np
import datatable as dt
from tqdm import tqdm
from scipy import sparse

from typing import Literal

import warnings
warnings.filterwarnings("ignore")

import dill as pickle
import gzip
from packaging.version import Version
torch_version = torch.__version__


def set_random_seed(seed):
    # Set seed for PyTorch
    torch.manual_seed(seed)

    # If using CUDA, set the seed for CUDA
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU setups.

    # Set seed for NumPy
    np.random.seed(seed)

    # Set seed for Python's random module
    random.seed(seed)

    # Set seed for Pyro
    pyro.set_rng_seed(seed)

class SURE(nn.Module):
    """SUccinct REpresentation of single-omics cells

    Parameters
    ----------
    input_size
        Number of features (e.g., genes, peaks, proteins, etc.) per cell.
    codebook_size
        Number of metacells.
    cell_factor_size
        Number of cell-level factors (e.g., batch indicators). They are used to
        adjust for undesired variations such as batch effects.
    supervised_mode
        A boolean option. If toggled on, metacell assignments are predicted
        directly from the input data rather than from the latent states.
    z_dim
        Dimensionality of latent states and metacells.
    z_dist
        The distribution model for latent states.

        One of the following:
        * ``'normal'`` - normal distribution (default)
        * ``'laplacian'`` - Laplace distribution
        * ``'studentt'`` - Student's t distribution
        * ``'cauchy'`` - Cauchy distribution
        * ``'gumbel'`` - Gumbel distribution
    loss_func
        The likelihood model for single-cell data generation.

        One of the following:
        * ``'negbinomial'`` - negative binomial distribution (default)
        * ``'poisson'`` - Poisson distribution
        * ``'multinomial'`` - multinomial distribution
    hidden_layers
        A list giving the number of neurons in each hidden layer.
    use_cuda
        A boolean option for switching on the CUDA device.

    Examples
    --------
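    A minimal usage sketch (the import path, data, and shapes below are
    illustrative assumptions, not taken from this release):

    >>> import numpy as np
    >>> from SURE import SURE
    >>> xs = np.random.poisson(1.0, size=(1000, 2000)).astype('float32')
    >>> model = SURE(input_size=2000, codebook_size=100, z_dim=10)
    >>> model.fit(xs, num_epochs=10)
    >>> z = model.get_cell_embedding(xs)   # latent states, shape (1000, 10)
    >>> a = model.hard_assignments(xs)     # one-hot metacell assignments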
    """
    def __init__(self,
                 input_size: int,
                 codebook_size: int = 200,
                 cell_factor_size: int = 0,
                 supervised_mode: bool = False,
                 z_dim: int = 10,
                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
                 loss_func: Literal['negbinomial','poisson','multinomial'] = 'negbinomial',
                 inverse_dispersion: float = 10.0,
                 use_zeroinflate: bool = False,
                 hidden_layers: list = [500],
                 hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                 nn_dropout: float = 0.1,
                 post_layer_fct: list = ['layernorm'],
                 post_act_fct: list = None,
                 config_enum: str = 'parallel',
                 use_cuda: bool = False,
                 seed: int = 42,
                 dtype=torch.float32,  # type: ignore
                 ):
        super().__init__()

        self.input_size = input_size
        self.cell_factor_size = cell_factor_size
        self.inverse_dispersion = inverse_dispersion
        self.latent_dim = z_dim
        self.hidden_layers = hidden_layers
        self.decoder_hidden_layers = hidden_layers[::-1]
        self.allow_broadcast = config_enum == 'parallel'
        self.use_cuda = use_cuda
        self.loss_func = loss_func
        self.options = None
        self.code_size = codebook_size
        self.supervised_mode = supervised_mode
        self.latent_dist = z_dist
        self.dtype = dtype
        self.use_zeroinflate = use_zeroinflate
        self.nn_dropout = nn_dropout
        self.post_layer_fct = post_layer_fct
        self.post_act_fct = post_act_fct
        self.hidden_layer_activation = hidden_layer_activation

        self.codebook_weights = None

        set_random_seed(seed)
        self.setup_networks()

    def setup_networks(self):
        latent_dim = self.latent_dim
        hidden_sizes = self.hidden_layers

        nn_layer_norm, nn_batch_norm, nn_layer_dropout = False, False, False
        na_layer_norm, na_batch_norm, na_layer_dropout = False, False, False

        if self.post_layer_fct is not None:
            nn_layer_norm = ('layernorm' in self.post_layer_fct) or ('layer_norm' in self.post_layer_fct)
            nn_batch_norm = ('batchnorm' in self.post_layer_fct) or ('batch_norm' in self.post_layer_fct)
            nn_layer_dropout = 'dropout' in self.post_layer_fct

        if self.post_act_fct is not None:
            na_layer_norm = ('layernorm' in self.post_act_fct) or ('layer_norm' in self.post_act_fct)
            na_batch_norm = ('batchnorm' in self.post_act_fct) or ('batch_norm' in self.post_act_fct)
            na_layer_dropout = 'dropout' in self.post_act_fct

        def _make_post_fct(use_dropout, use_batch_norm, use_layer_norm):
            # Compose the requested post-processing ops in a fixed order:
            # dropout -> batch norm -> layer norm (matching the original
            # branch-by-branch construction).
            def _post_fct(layer_ix, total_layers, layer):
                ops = []
                if use_dropout:
                    ops.append(nn.Dropout(self.nn_dropout))
                if use_batch_norm:
                    ops.append(nn.BatchNorm1d(layer.module.out_features))
                if use_layer_norm:
                    ops.append(nn.LayerNorm(layer.module.out_features))
                if not ops:
                    return None
                return ops[0] if len(ops) == 1 else nn.Sequential(*ops)
            return _post_fct

        post_layer_fct = _make_post_fct(nn_layer_dropout, nn_batch_norm, nn_layer_norm)
        post_act_fct = _make_post_fct(na_layer_dropout, na_batch_norm, na_layer_norm)

        if self.hidden_layer_activation == 'relu':
            activate_fct = nn.ReLU
        elif self.hidden_layer_activation == 'softplus':
            activate_fct = nn.Softplus
        elif self.hidden_layer_activation == 'leakyrelu':
            activate_fct = nn.LeakyReLU
        elif self.hidden_layer_activation == 'linear':
            activate_fct = nn.Identity
        else:
            raise ValueError(f"unknown hidden_layer_activation: {self.hidden_layer_activation}")

        # Metacell assignment network: maps a latent state (or, in supervised
        # mode, the raw input) to logits over the codebook.
        if self.supervised_mode:
            self.encoder_n = MLP(
                [self.input_size] + hidden_sizes + [self.code_size],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )
        else:
            self.encoder_n = MLP(
                [self.latent_dim] + hidden_sizes + [self.code_size],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )

        # Latent-state encoder: maps input to the location and (positive)
        # scale of the variational posterior over zn.
        self.encoder_zn = MLP(
            [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
            activation=activate_fct,
            output_activation=[None, Exp],
            post_layer_fct=post_layer_fct,
            post_act_fct=post_act_fct,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )

        if self.cell_factor_size > 0:
            self.cell_factor_effect = MLP(
                [self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )

        self.decoder_concentrate = MLP(
            [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
            activation=activate_fct,
            output_activation=None,
            post_layer_fct=post_layer_fct,
            post_act_fct=post_act_fct,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )

        # Codebook network: maps one-hot metacell identities to latent
        # locations (plus degrees of freedom for the Student's t case).
        if self.latent_dist == 'studentt':
            self.codebook = MLP(
                [self.code_size] + hidden_sizes + [[latent_dim, latent_dim]],
                activation=activate_fct,
                output_activation=[Exp, None],
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )
        else:
            self.codebook = MLP(
                [self.code_size] + hidden_sizes + [latent_dim],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
                post_act_fct=post_act_fct,
                allow_broadcast=self.allow_broadcast,
                use_cuda=self.use_cuda,
            )

        if self.use_cuda:
            self.cuda()

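    # Network shapes at a glance (D = input_size, H = hidden_layers,
    # K = codebook_size, Z = z_dim); a summary of the MLPs built above:
    #   encoder_zn : D -> H -> (Z, Z)     posterior location and scale of zn
    #   encoder_n  : Z (or D) -> H -> K   metacell assignment logits
    #   codebook   : K -> H -> Z          metacell locations (+ dof for studentt)
    #   decoder_concentrate : Z -> reversed(H) -> D   feature concentrations
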
    def get_device(self):
        return next(self.parameters()).device

    def cutoff(self, xs, thresh=None):
        eps = torch.finfo(xs.dtype).eps

        if thresh is not None:
            if eps < thresh:
                eps = thresh

        xs = xs.clamp(min=eps)

        if torch.any(torch.isnan(xs)):
            xs[torch.isnan(xs)] = eps

        return xs

    def softmax(self, xs):
        #xs = SoftmaxTransform()(xs)
        xs = dist.Multinomial(total_count=1, logits=xs).mean
        return xs

    def sigmoid(self, xs):
        #sigm_enc = nn.Sigmoid()
        #xs = sigm_enc(xs)
        #xs = clamp_probs(xs)
        xs = dist.Bernoulli(logits=xs).mean
        return xs

    def softmax_logit(self, xs):
        eps = torch.finfo(xs.dtype).eps
        xs = self.softmax(xs)
        xs = torch.logit(xs, eps=eps)
        return xs

    def logit(self, xs):
        eps = torch.finfo(xs.dtype).eps
        xs = torch.logit(xs, eps=eps)
        return xs

    def dirimulti_param(self, xs):
        # NOTE: relies on self.dirimulti_mass, which is never set in
        # __init__; this helper is unused in the current release.
        xs = self.dirimulti_mass * self.sigmoid(xs)
        return xs

    def multi_param(self, xs):
        xs = self.softmax(xs)
        return xs

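    # These helpers compute probabilities through distribution means, e.g.
    # softmax(xs) == dist.Multinomial(1, logits=xs).mean, which matches
    # torch.softmax(xs, dim=-1); sigmoid(xs) likewise equals torch.sigmoid(xs).
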
    def model1(self, xs):
        # Generative model for unsupervised fitting without cell-level
        # factors.
        pyro.module('sure', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            prior = torch.zeros(batch_size, self.code_size, **self.options)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))

            zn_loc = torch.matmul(ns, acs_loc)
            #zn_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
            elif self.latent_dist == 'laplacian':
                zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'cauchy':
                zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'normal':
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'gumbel':
                zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

            zs = zns
            concentrate = self.decoder_concentrate(zs)
            rate = concentrate.exp()
            if self.loss_func != 'poisson':
                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)

    def guide1(self, xs):
        # Variational guide: amortized posterior over zn, then metacell
        # assignment logits from the sampled latent state.
        with pyro.plate('data'):
            zn_loc, zn_scale = self.encoder_zn(xs)
            zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

            alpha = self.encoder_n(zns)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

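    # model1-model4 and their guides share the generative core used above:
    #
    #     n  ~ OneHotCategorical(logits)                    metacell identity
    #     zn ~ z_dist(codebook(n), codebook_scale)          latent state
    #     x  ~ NegBinomial/Poisson/Multinomial(decode(zn))  observed counts
    #
    # model2/guide2 add cell-level factors us; model3/guide3 observe the
    # assignments ys (and optionally fixed embeddings); model4/guide4
    # combine both.
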
    def model2(self, xs, us=None):
        # As model1, but the decoded effect of cell-level factors us is
        # added to the latent state before decoding.
        pyro.module('sure', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            prior = torch.zeros(batch_size, self.code_size, **self.options)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))

            zn_loc = torch.matmul(ns, acs_loc)
            #zn_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
            elif self.latent_dist == 'laplacian':
                zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'cauchy':
                zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'normal':
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
            elif self.latent_dist == 'gumbel':
                zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

            if self.cell_factor_size > 0:
                zus = self.cell_factor_effect(us)
                zs = zns + zus
            else:
                zs = zns

            concentrate = self.decoder_concentrate(zs)
            rate = concentrate.exp()
            if self.loss_func != 'poisson':
                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)

    def guide2(self, xs, us=None):
        with pyro.plate('data'):
            zn_loc, zn_scale = self.encoder_zn(xs)
            zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

            alpha = self.encoder_n(zns)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

    def model3(self, xs, ys, embeds=None):
        # As model1, but metacell assignments ys are observed (supervised),
        # and latent states can optionally be clamped to given embeddings.
        pyro.module('sure', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            #prior = torch.zeros(batch_size, self.code_size, **self.options)
            prior = self.encoder_n(xs)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)

            zn_loc = torch.matmul(ns, acs_loc)
            #prior_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                if embeds is None:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'laplacian':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'cauchy':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'normal':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'gumbel':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)

            zs = zns

            concentrate = self.decoder_concentrate(zs)
            rate = concentrate.exp()
            if self.loss_func != 'poisson':
                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)

    def guide3(self, xs, ys, embeds=None):
        with pyro.plate('data'):
            if embeds is None:
                zn_loc, zn_scale = self.encoder_zn(xs)
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

    def model4(self, xs, us, ys, embeds=None):
        # As model3, but additionally with cell-level factors us.
        pyro.module('sure', self)

        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
        self.options = dict(dtype=xs.dtype, device=xs.device)

        if self.loss_func == 'negbinomial':
            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
                                     xs.new_ones(self.input_size), constraint=constraints.positive)

        if self.use_zeroinflate:
            gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))

        acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)

        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            acs_dof, acs_loc = self.codebook(I)
        else:
            acs_loc = self.codebook(I)

        with pyro.plate('data'):
            #prior = torch.zeros(batch_size, self.code_size, **self.options)
            prior = self.encoder_n(xs)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)

            zn_loc = torch.matmul(ns, acs_loc)
            #prior_scale = torch.matmul(ns, acs_scale)
            zn_scale = acs_scale

            if self.latent_dist == 'studentt':
                prior_dof = torch.matmul(ns, acs_dof)
                if embeds is None:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'laplacian':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'cauchy':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'normal':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
            elif self.latent_dist == 'gumbel':
                if embeds is None:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
                else:
                    zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)

            if self.cell_factor_size > 0:
                zus = self.cell_factor_effect(us)
                zs = zus + zns
            else:
                zs = zns

            concentrate = self.decoder_concentrate(zs)
            rate = concentrate.exp()
            if self.loss_func != 'poisson':
                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
                else:
                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            elif self.loss_func == 'poisson':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
                else:
                    pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
            elif self.loss_func == 'multinomial':
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)

    def guide4(self, xs, us, ys, embeds=None):
        with pyro.plate('data'):
            if embeds is None:
                zn_loc, zn_scale = self.encoder_zn(xs)
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

    def _get_codebook(self):
        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist == 'studentt':
            _, cb = self.codebook(I)
        else:
            cb = self.codebook(I)
        return cb

    def get_codebook(self):
        """
        Return the mean part of the metacell codebook.
        """
        cb = self._get_codebook()
        cb = tensor_to_numpy(cb)
        return cb

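    # Codebook sketch: the returned array holds one latent location per
    # metacell, e.g.
    #
    #     cb = model.get_codebook()
    #     cb.shape   # (codebook_size, z_dim)
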
    def _get_cell_embedding(self, xs):
        zns, _ = self.encoder_zn(xs)
        return zns

    def get_cell_embedding(self,
                           xs,
                           batch_size: int = 1024):
        """
        Return cells' latent representations.

        Parameters
        ----------
        xs
            Single-cell expression matrix. It should be a NumPy array or a PyTorch Tensor.
        batch_size
            Size of batch processing.

        Examples
        --------
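        A sketch, assuming `model` is a fitted SURE instance and `xs` has
        ``input_size`` columns:

        >>> z = model.get_cell_embedding(xs, batch_size=512)
        >>> z.shape   # (n_cells, z_dim)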
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        Z = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                zns = self._get_cell_embedding(X_batch)
                Z.append(tensor_to_numpy(zns))
                pbar.update(1)

        Z = np.concatenate(Z)
        return Z

    def _code(self, xs):
        # In supervised mode the assignment network reads raw inputs;
        # otherwise it reads the encoded latent location.
        if self.supervised_mode:
            alpha = self.encoder_n(xs)
        else:
            zns, _ = self.encoder_zn(xs)
            alpha = self.encoder_n(zns)
        return alpha

    def code(self, xs, batch_size=1024):
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        A = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                a = self._code(X_batch)
                A.append(tensor_to_numpy(a))
                pbar.update(1)

        A = np.concatenate(A)
        return A

    def _soft_assignments(self, xs):
        alpha = self._code(xs)
        alpha = self.softmax(alpha)
        return alpha

    def soft_assignments(self, xs, batch_size=1024):
        """
        Map cells to metacells and return the probabilistic values of metacell assignments.
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        A = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                a = self._soft_assignments(X_batch)
                A.append(tensor_to_numpy(a))
                pbar.update(1)

        A = np.concatenate(A)
        return A

    def _hard_assignments(self, xs):
        alpha = self._code(xs)
        res, ind = torch.topk(alpha, 1)
        ns = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
        return ns

    def hard_assignments(self, xs, batch_size=1024):
        """
        Map cells to metacells and return the assigned metacell identities.
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, device=self.get_device())
        dataset = CustomDataset(xs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        A = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
                a = self._hard_assignments(X_batch)
                A.append(tensor_to_numpy(a))
                pbar.update(1)

        A = np.concatenate(A)
        return A

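    # Assignment usage sketch (illustrative; `model` is a fitted SURE
    # instance and `xs` a cell-by-feature matrix):
    #
    #     P = model.soft_assignments(xs)   # (n_cells, codebook_size), rows sum to 1
    #     H = model.hard_assignments(xs)   # one-hot rows
    #     labels = H.argmax(axis=1)        # integer metacell identities
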
    def preprocess(self, xs):
        # Densify sparse inputs; dense arrays and tensors pass through unchanged.
        if sparse.issparse(xs):
            xs = xs.toarray()
        return xs

    def fit(self, xs,
            us=None,
            ys=None,
            zs=None,
            num_epochs: int = 200,
            learning_rate: float = 0.0001,
            batch_size: int = 256,
            algo: Literal['adam','rmsprop','adamw'] = 'adam',
            beta_1: float = 0.9,
            weight_decay: float = 0.005,
            decay_rate: float = 0.9,
            config_enum: str = 'parallel',
            use_jax: bool = False):
        """
        Train the SURE model.

        Parameters
        ----------
        xs
            Single-cell expression matrix. It should be a NumPy array or a PyTorch Tensor. Rows are cells and columns are features.
        us
            Undesired factor matrix. It should be a NumPy array or a PyTorch Tensor. Rows are cells and columns are undesired factors.
        ys
            Desired factor matrix. It should be a NumPy array or a PyTorch Tensor. Rows are cells and columns are desired factors.
        zs
            Optional matrix of fixed cell embeddings for the model to condition on.
        num_epochs
            Number of training epochs.
        learning_rate
            Learning rate of the optimizer.
        batch_size
            Size of batch processing.
        algo
            Optimization algorithm.
        beta_1
            Beta-1 parameter of the optimizer.
        weight_decay
            Weight decay of the optimizer.
        decay_rate
            Decay rate of the learning-rate scheduler.
        use_jax
            If toggled on, the JIT-compiled ELBO (``JitTraceEnum_ELBO``) is used to speed up training.
            CAUTION: for unknown reasons this may raise errors when SURE is called from a Python
            script or a Jupyter notebook; it works when running SURE from the shell command.

        Examples
        --------
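        A training sketch (the arrays `xs` and `us` are illustrative; `us`
        one-hot encodes a batch covariate):

        >>> model = SURE(input_size=xs.shape[1], cell_factor_size=us.shape[1])
        >>> model.fit(xs, us=us, num_epochs=100, batch_size=512)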
        """
        xs = self.preprocess(xs)
        xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
        if us is not None:
            us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
        if ys is not None:
            ys = convert_to_tensor(ys, dtype=self.dtype, device=self.get_device())
        if zs is not None:
            zs = convert_to_tensor(zs, dtype=self.dtype, device=self.get_device())

        dataset = CustomDataset4(xs, us, ys, zs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # set up the optimizer
        optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}

        if algo.lower() == 'rmsprop':
            optimizer = torch.optim.RMSprop
        elif algo.lower() == 'adam':
            optimizer = torch.optim.Adam
        elif algo.lower() == 'adamw':
            optimizer = torch.optim.AdamW
        else:
            raise ValueError("A valid optimization algorithm must be specified.")
        scheduler = ExponentialLR({'optimizer': optimizer, 'optim_args': optim_params, 'gamma': decay_rate})

        pyro.clear_param_store()

        # Set up the loss(es) for inference. Wrapping the guide in
        # config_enumerate builds the loss as a sum by enumerating each class
        # label from the sampled discrete categorical distribution in the model.
        Elbo = JitTraceEnum_ELBO if use_jax else TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
        if us is None:
            if ys is None:
                guide = config_enumerate(self.guide1, config_enum, expand=True)
                loss_basic = SVI(self.model1, guide, scheduler, loss=elbo)
            else:
                guide = config_enumerate(self.guide3, config_enum, expand=True)
                loss_basic = SVI(self.model3, guide, scheduler, loss=elbo)
        else:
            if ys is None:
                guide = config_enumerate(self.guide2, config_enum, expand=True)
                loss_basic = SVI(self.model2, guide, scheduler, loss=elbo)
            else:
                guide = config_enumerate(self.guide4, config_enum, expand=True)
                loss_basic = SVI(self.model4, guide, scheduler, loss=elbo)

        # build a list of all losses considered
        losses = [loss_basic]
        num_losses = len(losses)

        with tqdm(total=num_epochs, desc='Training', unit='epoch') as pbar:
            for epoch in range(num_epochs):
                epoch_losses = [0.0] * num_losses
                for batch_x, batch_u, batch_y, batch_z, _ in dataloader:
                    if us is None:
                        batch_u = None
                    if ys is None:
                        batch_y = None
                    if zs is None:
                        batch_z = None

                    for loss_id in range(num_losses):
                        if batch_u is None:
                            if batch_y is None:
                                new_loss = losses[loss_id].step(batch_x)
                            else:
                                new_loss = losses[loss_id].step(batch_x, batch_y, batch_z)
                        else:
                            if batch_y is None:
                                new_loss = losses[loss_id].step(batch_x, batch_u)
                            else:
                                new_loss = losses[loss_id].step(batch_x, batch_u, batch_y, batch_z)
                        epoch_losses[loss_id] += new_loss

                avg_epoch_losses_ = map(lambda v: v / len(dataloader), epoch_losses)
                avg_epoch_losses = map(lambda v: "{:.4f}".format(v), avg_epoch_losses_)

                # store the loss
                str_loss = " ".join(map(str, avg_epoch_losses))

                # Update progress bar
                pbar.set_postfix({'loss': str_loss})
                pbar.update(1)

        # Cache normalized metacell weights from the final soft assignments.
        assigns = self.soft_assignments(xs)
        assigns = convert_to_tensor(assigns, dtype=self.dtype, device=self.get_device())
        self.codebook_weights = torch.sum(assigns, dim=0)
        self.codebook_weights = self.codebook_weights / torch.sum(self.codebook_weights)
        #self.inverse_dispersion = pyro.param('inverse_dispersion').item()
        #self.dof = pyro.param('dof').item()

    @classmethod
    def save_model(cls, model, file_path, compression=False):
        """Save the model to the specified file path."""
        file_path = os.path.abspath(file_path)

        model.eval()
        if compression:
            with gzip.open(file_path, 'wb') as pickle_file:
                pickle.dump(model, pickle_file)
        else:
            with open(file_path, 'wb') as pickle_file:
                pickle.dump(model, pickle_file)

        print(f'Model saved to {file_path}')

    @classmethod
    def load_model(cls, file_path):
        """Load the model from the specified file path and return an instance."""
        file_path = os.path.abspath(file_path)
        if file_path.endswith('gz'):
            with gzip.open(file_path, 'rb') as pickle_file:
                model = pickle.load(pickle_file)
        else:
            with open(file_path, 'rb') as pickle_file:
                model = pickle.load(pickle_file)

        print(f'Model loaded from {file_path}')
        return model


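# Persistence sketch (the file name is a placeholder): pass compression=True
# to gzip the pickle on save; load_model() detects gzip by the 'gz' suffix.
#
#     SURE.save_model(model, 'sure_model.pkl.gz', compression=True)
#     model = SURE.load_model('sure_model.pkl.gz')
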
EXAMPLE_RUN = (
    "example run: SURE --help"
)
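
# Illustrative shell invocation (flags are defined in parse_args below; the
# CSV file names are placeholders):
#
#   SURE -data counts.csv -undesired batch.csv -cs 200 -zd 10 -n 150 --save-model sure.pkl.gz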

def parse_args():
    parser = argparse.ArgumentParser(
        description="SURE\n{}".format(EXAMPLE_RUN))

    parser.add_argument(
        "--cuda", action="store_true", help="use GPU(s) to speed up training"
    )
    parser.add_argument(
        "--jit", action="store_true", help="use PyTorch jit to speed up training"
    )
    parser.add_argument(
        "-n", "--num-epochs", default=200, type=int, help="number of epochs to run"
    )
    parser.add_argument(
        "-enum",
        "--enum-discrete",
        default="parallel",
        help="parallel, sequential or none. uses parallel enumeration by default",
    )
    parser.add_argument(
        "-data",
        "--data-file",
        default=None,
        type=str,
        help="the data file",
    )
    parser.add_argument(
        "-undesired",
        "--undesired-factor-file",
        default=None,
        type=str,
        help="the file recording undesired factors",
    )
    parser.add_argument(
        "-delta",
        "--delta",
        default=0.0,
        type=float,
        help="penalty weight for zero inflation loss",
    )
    parser.add_argument(
        "-64",
        "--float64",
        action="store_true",
        help="use double float precision",
    )
    parser.add_argument(
        "--z-dist",
        default='normal',
        type=str,
        choices=['normal', 'laplacian', 'studentt', 'cauchy', 'gumbel'],
        help="distribution model for latent representation",
    )
    parser.add_argument(
        "-cs",
        "--codebook-size",
        default=100,
        type=int,
        help="size of vector quantization codebook",
    )
    parser.add_argument(
        "-zd",
        "--z-dim",
        default=10,
        type=int,
        help="size of the tensor representing the latent variable z",
    )
    parser.add_argument(
        "-hl",
        "--hidden-layers",
        nargs="+",
        default=[500],
        type=int,
        help="a tuple (or list) of MLP layers to be used in the neural networks "
        "representing the parameters of the distributions in our model",
    )
    parser.add_argument(
        "-hla",
        "--hidden-layer-activation",
        default='relu',
        type=str,
        choices=['relu', 'softplus', 'leakyrelu', 'linear'],
        help="activation function for hidden layers",
    )
    parser.add_argument(
        "-plf",
        "--post-layer-function",
        nargs="+",
        default=['layernorm'],
        type=str,
        help="post functions for hidden layers, could be none, dropout, layernorm, batchnorm, or a combination; default is 'layernorm'",
    )
    parser.add_argument(
        "-paf",
        "--post-activation-function",
        nargs="+",
        default=['none'],
        type=str,
        help="post functions for activation layers, could be none or dropout; default is 'none'",
    )
    parser.add_argument(
        "-id",
        "--inverse-dispersion",
        default=10.0,
        type=float,
        help="inverse dispersion prior for negative binomial",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=0.0001,
        type=float,
        help="learning rate for Adam optimizer",
    )
    parser.add_argument(
        "-dr",
        "--decay-rate",
        default=0.9,
        type=float,
        help="decay rate for Adam optimizer",
    )
    parser.add_argument(
        "--layer-dropout-rate",
        default=0.1,
        type=float,
        help="dropout rate for neural networks",
    )
    parser.add_argument(
        "-b1",
        "--beta-1",
        default=0.95,
        type=float,
        help="beta-1 parameter for Adam optimizer",
    )
    parser.add_argument(
        "-bs",
        "--batch-size",
        default=1000,
        type=int,
        help="number of cells to be considered in a batch",
    )
    parser.add_argument(
        "-gp",
        "--gate-prior",
        default=0.6,
        type=float,
        help="gate prior for zero-inflated model",
    )
    parser.add_argument(
        "-likeli",
        "--likelihood",
        default='negbinomial',
        type=str,
        choices=['negbinomial', 'multinomial', 'poisson'],
        help="specify the distribution likelihood function",
    )
    parser.add_argument(
        "-dirichlet",
        "--use-dirichlet",
        action="store_true",
        help="use Dirichlet distribution over gene frequency",
    )
    parser.add_argument(
        "-mass",
        "--dirichlet-mass",
        default=1,
        type=float,
        help="mass param for dirichlet model",
    )
    parser.add_argument(
        "-zi",
        "--zero-inflation",
        default='exact',
        type=str,
        choices=['none', 'exact', 'inexact'],
        help="use zero-inflated estimation",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="seed for controlling randomness in this example",
    )
    parser.add_argument(
        "--save-model",
        default=None,
        type=str,
        help="path to save model for prediction",
    )
    args = parser.parse_args()
    return args

def main():
    args = parse_args()
    assert (
        (args.data_file is not None) and (
            os.path.exists(args.data_file))
    ), "data file must be provided"

    if args.seed is not None:
        set_random_seed(args.seed)

    if args.float64:
        dtype = torch.float64
        torch.set_default_dtype(torch.float64)
    else:
        dtype = torch.float32
        torch.set_default_dtype(torch.float32)

    xs = dt.fread(file=args.data_file, header=True).to_numpy()
    us = None
    if args.undesired_factor_file is not None:
        us = dt.fread(file=args.undesired_factor_file, header=True).to_numpy()

    input_size = xs.shape[1]
    undesired_size = 0 if us is None else us.shape[1]

    ###########################################
    # Keyword names below follow the SURE.__init__ signature defined above.
    # CLI flags without a matching constructor parameter (delta, gate-prior,
    # and the dirichlet options) are accepted by parse_args but ignored here.
    sure = SURE(
        input_size=input_size,
        cell_factor_size=undesired_size,
        inverse_dispersion=args.inverse_dispersion,
        z_dim=args.z_dim,
        hidden_layers=args.hidden_layers,
        hidden_layer_activation=args.hidden_layer_activation,
        use_cuda=args.cuda,
        config_enum=args.enum_discrete,
        use_zeroinflate=args.zero_inflation != 'none',
        loss_func=args.likelihood,
        nn_dropout=args.layer_dropout_rate,
        post_layer_fct=args.post_layer_function,
        post_act_fct=args.post_activation_function,
        codebook_size=args.codebook_size,
        z_dist=args.z_dist,
        dtype=dtype,
    )

    sure.fit(xs, us=us,
             num_epochs=args.num_epochs,
             learning_rate=args.learning_rate,
             batch_size=args.batch_size,
             beta_1=args.beta_1,
             decay_rate=args.decay_rate,
             use_jax=args.jit,
             config_enum=args.enum_discrete,
             )

    if args.save_model is not None:
        if args.save_model.endswith('gz'):
            SURE.save_model(sure, args.save_model, compression=True)
        else:
            SURE.save_model(sure, args.save_model)


if __name__ == "__main__":
    main()