SURE-tools 2.2.24__py3-none-any.whl → 2.4.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/PerturbE.py ADDED
@@ -0,0 +1,1300 @@
+ import pyro
+ import pyro.distributions as dist
+ from pyro.optim import ExponentialLR
+ from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_probs
+ from torch.distributions import constraints
+ from torch.distributions.transforms import SoftmaxTransform
+
+ from .utils.custom_mlp import MLP, Exp, ZeroBiasMLP2
+ from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor
+
+
+ import os
+ import argparse
+ import random
+ import numpy as np
+ import datatable as dt
+ from tqdm import tqdm
+ from scipy import sparse
+
+ import scanpy as sc
+ from .atac import binarize
+
+ from typing import Literal
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ import dill as pickle
+ import gzip
+ from packaging.version import Version
+ torch_version = torch.__version__
+
+
+ def set_random_seed(seed):
+     # Set seed for PyTorch
+     torch.manual_seed(seed)
+
+     # If using CUDA, set the seed for CUDA
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)  # For multi-GPU setups.
+
+     # Set seed for NumPy
+     np.random.seed(seed)
+
+     # Set seed for Python's random module
+     random.seed(seed)
+
+     # Set seed for Pyro
+     pyro.set_rng_seed(seed)
+
+ class PerturbE(nn.Module):
+     def __init__(self,
+                  input_size: int,
+                  codebook_size: int = 200,
+                  cell_factor_size: int = 0,
+                  supervised_mode: bool = False,
+                  z_dim: int = 10,
+                  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
+                  loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
+                  inverse_dispersion: float = 10.0,
+                  use_zeroinflate: bool = False,
+                  hidden_layers: list = [500],
+                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
+                  nn_dropout: float = 0.1,
+                  post_layer_fct: list = ['layernorm'],
+                  post_act_fct: list = None,
+                  config_enum: str = 'parallel',
+                  use_cuda: bool = True,
+                  seed: int = 42,
+                  dtype = torch.float32,  # type: ignore
+                  ):
+         super().__init__()
+
+         self.input_size = input_size
+         self.cell_factor_size = cell_factor_size
+         self.inverse_dispersion = inverse_dispersion
+         self.latent_dim = z_dim
+         self.hidden_layers = hidden_layers
+         self.decoder_hidden_layers = hidden_layers[::-1]
+         self.allow_broadcast = config_enum == 'parallel'
+         self.use_cuda = use_cuda
+         self.loss_func = loss_func
+         self.options = None
+         self.code_size = codebook_size
+         self.supervised_mode = supervised_mode
+         self.latent_dist = z_dist
+         self.dtype = dtype
+         self.use_zeroinflate = use_zeroinflate
+         self.nn_dropout = nn_dropout
+         self.post_layer_fct = post_layer_fct
+         self.post_act_fct = post_act_fct
+         self.hidden_layer_activation = hidden_layer_activation
+
+         self.codebook_weights = None
+
+         set_random_seed(seed)
+         self.setup_networks()
+
+     def setup_networks(self):
+         latent_dim = self.latent_dim
+         hidden_sizes = self.hidden_layers
+
+         nn_layer_norm, nn_batch_norm, nn_layer_dropout = False, False, False
+         na_layer_norm, na_batch_norm, na_layer_dropout = False, False, False
+
+         if self.post_layer_fct is not None:
+             nn_layer_norm = ('layernorm' in self.post_layer_fct) or ('layer_norm' in self.post_layer_fct)
+             nn_batch_norm = ('batchnorm' in self.post_layer_fct) or ('batch_norm' in self.post_layer_fct)
+             nn_layer_dropout = 'dropout' in self.post_layer_fct
+
+         if self.post_act_fct is not None:
+             na_layer_norm = ('layernorm' in self.post_act_fct) or ('layer_norm' in self.post_act_fct)
+             na_batch_norm = ('batchnorm' in self.post_act_fct) or ('batch_norm' in self.post_act_fct)
+             na_layer_dropout = 'dropout' in self.post_act_fct
+
+         if nn_layer_norm and nn_batch_norm and nn_layer_dropout:
+             post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
+         elif nn_layer_norm and nn_layer_dropout:
+             post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.LayerNorm(layer.module.out_features))
+         elif nn_batch_norm and nn_layer_dropout:
+             post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features))
+         elif nn_layer_norm and nn_batch_norm:
+             post_layer_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
+         elif nn_layer_norm:
+             post_layer_fct = lambda layer_ix, total_layers, layer: nn.LayerNorm(layer.module.out_features)
+         elif nn_batch_norm:
+             post_layer_fct = lambda layer_ix, total_layers, layer: nn.BatchNorm1d(layer.module.out_features)
+         elif nn_layer_dropout:
+             post_layer_fct = lambda layer_ix, total_layers, layer: nn.Dropout(self.nn_dropout)
+         else:
+             post_layer_fct = lambda layer_ix, total_layers, layer: None
+
+         if na_layer_norm and na_batch_norm and na_layer_dropout:
+             post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
+         elif na_layer_norm and na_layer_dropout:
+             post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.LayerNorm(layer.module.out_features))
+         elif na_batch_norm and na_layer_dropout:
+             post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.Dropout(self.nn_dropout), nn.BatchNorm1d(layer.module.out_features))
+         elif na_layer_norm and na_batch_norm:
+             post_act_fct = lambda layer_ix, total_layers, layer: nn.Sequential(nn.BatchNorm1d(layer.module.out_features), nn.LayerNorm(layer.module.out_features))
+         elif na_layer_norm:
+             post_act_fct = lambda layer_ix, total_layers, layer: nn.LayerNorm(layer.module.out_features)
+         elif na_batch_norm:
+             post_act_fct = lambda layer_ix, total_layers, layer: nn.BatchNorm1d(layer.module.out_features)
+         elif na_layer_dropout:
+             post_act_fct = lambda layer_ix, total_layers, layer: nn.Dropout(self.nn_dropout)
+         else:
+             post_act_fct = lambda layer_ix, total_layers, layer: None
+
+         if self.hidden_layer_activation == 'relu':
+             activate_fct = nn.ReLU
+         elif self.hidden_layer_activation == 'softplus':
+             activate_fct = nn.Softplus
+         elif self.hidden_layer_activation == 'leakyrelu':
+             activate_fct = nn.LeakyReLU
+         elif self.hidden_layer_activation == 'linear':
+             activate_fct = nn.Identity
+
+         if self.supervised_mode:
+             self.encoder_n = MLP(
+                 [self.input_size] + hidden_sizes + [self.code_size],
+                 activation=activate_fct,
+                 output_activation=None,
+                 post_layer_fct=post_layer_fct,
+                 post_act_fct=post_act_fct,
+                 allow_broadcast=self.allow_broadcast,
+                 use_cuda=self.use_cuda,
+             )
+         else:
+             self.encoder_n = MLP(
+                 [self.latent_dim] + hidden_sizes + [self.code_size],
+                 activation=activate_fct,
+                 output_activation=None,
+                 post_layer_fct=post_layer_fct,
+                 post_act_fct=post_act_fct,
+                 allow_broadcast=self.allow_broadcast,
+                 use_cuda=self.use_cuda,
+             )
+
+         self.encoder_zn = MLP(
+             [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
+             activation=activate_fct,
+             output_activation=[None, Exp],
+             post_layer_fct=post_layer_fct,
+             post_act_fct=post_act_fct,
+             allow_broadcast=self.allow_broadcast,
+             use_cuda=self.use_cuda,
+         )
+
+         if self.cell_factor_size > 0:
+             self.cell_factor_effect = ZeroBiasMLP2(
+                 [self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
+                 activation=activate_fct,
+                 output_activation=None,
+                 post_layer_fct=post_layer_fct,
+                 post_act_fct=post_act_fct,
+                 allow_broadcast=self.allow_broadcast,
+                 use_cuda=self.use_cuda,
+             )
+
+         self.decoder_concentrate = MLP(
+             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
+             activation=activate_fct,
+             output_activation=None,
+             post_layer_fct=post_layer_fct,
+             post_act_fct=post_act_fct,
+             allow_broadcast=self.allow_broadcast,
+             use_cuda=self.use_cuda,
+         )
+
+         if self.latent_dist == 'studentt':
+             self.codebook = MLP(
+                 [self.code_size] + hidden_sizes + [[latent_dim, latent_dim]],
+                 activation=activate_fct,
+                 output_activation=[Exp, None],
+                 post_layer_fct=post_layer_fct,
+                 post_act_fct=post_act_fct,
+                 allow_broadcast=self.allow_broadcast,
+                 use_cuda=self.use_cuda,
+             )
+         else:
+             self.codebook = MLP(
+                 [self.code_size] + hidden_sizes + [latent_dim],
+                 activation=activate_fct,
+                 output_activation=None,
+                 post_layer_fct=post_layer_fct,
+                 post_act_fct=post_act_fct,
+                 allow_broadcast=self.allow_broadcast,
+                 use_cuda=self.use_cuda,
+             )
+
+         if self.use_cuda:
+             self.cuda()
+
+     def get_device(self):
+         return next(self.parameters()).device
+
+     def cutoff(self, xs, thresh=None):
+         eps = torch.finfo(xs.dtype).eps
+
+         if thresh is not None:
+             if eps < thresh:
+                 eps = thresh
+
+         xs = xs.clamp(min=eps)
+
+         if torch.any(torch.isnan(xs)):
+             xs[torch.isnan(xs)] = eps
+
+         return xs
+
+     def softmax(self, xs):
+         #xs = SoftmaxTransform()(xs)
+         xs = dist.Multinomial(total_count=1, logits=xs).mean
+         return xs
+
+     def sigmoid(self, xs):
+         #sigm_enc = nn.Sigmoid()
+         #xs = sigm_enc(xs)
+         #xs = clamp_probs(xs)
+         xs = dist.Bernoulli(logits=xs).mean
+         return xs
+
+     def softmax_logit(self, xs):
+         eps = torch.finfo(xs.dtype).eps
+         xs = self.softmax(xs)
+         xs = torch.logit(xs, eps=eps)
+         return xs
+
+     def logit(self, xs):
+         eps = torch.finfo(xs.dtype).eps
+         xs = torch.logit(xs, eps=eps)
+         return xs
+
+     def dirimulti_param(self, xs):
+         # NOTE: relies on self.dirimulti_mass being set elsewhere before this is called
+         xs = self.dirimulti_mass * self.sigmoid(xs)
+         return xs
+
+     def multi_param(self, xs):
+         xs = self.softmax(xs)
+         return xs
+
+     def model1(self, xs):
+         pyro.module('PerturbE', self)
+
+         eps = torch.finfo(xs.dtype).eps
+         batch_size = xs.size(0)
+         self.options = dict(dtype=xs.dtype, device=xs.device)
+
+         if self.loss_func == 'negbinomial':
+             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+
+         if self.use_zeroinflate:
+             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
+
+         acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
+
+         I = torch.eye(self.code_size, **self.options)  # keep the identity on the data's device/dtype
+         if self.latent_dist == 'studentt':
+             acs_dof, acs_loc = self.codebook(I)
+         else:
+             acs_loc = self.codebook(I)
+
+         with pyro.plate('data'):
+             prior = torch.zeros(batch_size, self.code_size, **self.options)
+             ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))
+
+             zn_loc = torch.matmul(ns, acs_loc)
+             #zn_scale = torch.matmul(ns, acs_scale)
+             zn_scale = acs_scale
+
+             if self.latent_dist == 'studentt':
+                 prior_dof = torch.matmul(ns, acs_dof)
+                 zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
+             elif self.latent_dist == 'laplacian':
+                 zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
+             elif self.latent_dist == 'cauchy':
+                 zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
+             elif self.latent_dist == 'normal':
+                 zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
+             elif self.latent_dist == 'gumbel':
+                 zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
+
+             zs = zns
+             concentrate = self.decoder_concentrate(zs)
+             if self.loss_func in ['bernoulli']:
+                 log_theta = concentrate
+             else:
+                 rate = concentrate.exp()
+                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                 if self.loss_func == 'poisson':
+                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
+
+             if self.loss_func == 'negbinomial':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
+                 else:
+                     pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+             elif self.loss_func == 'poisson':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
+                 else:
+                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
+             elif self.loss_func == 'multinomial':
+                 #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+                 pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             elif self.loss_func == 'bernoulli':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
+                 else:
+                     pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
+
+     def guide1(self, xs):
+         with pyro.plate('data'):
+             #zn_loc, zn_scale = self.encoder_zn(xs)
+             zn_loc, zn_scale = self._get_basal_embedding(xs)
+             zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
+
+             alpha = self.encoder_n(zns)
+             ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
+
+     def model2(self, xs, us=None):
+         pyro.module('PerturbE', self)
+
+         eps = torch.finfo(xs.dtype).eps
+         batch_size = xs.size(0)
+         self.options = dict(dtype=xs.dtype, device=xs.device)
+
+         if self.loss_func == 'negbinomial':
+             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+
+         if self.use_zeroinflate:
+             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
+
+         acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
+
+         I = torch.eye(self.code_size, **self.options)  # keep the identity on the data's device/dtype
+         if self.latent_dist == 'studentt':
+             acs_dof, acs_loc = self.codebook(I)
+         else:
+             acs_loc = self.codebook(I)
+
+         with pyro.plate('data'):
+             prior = torch.zeros(batch_size, self.code_size, **self.options)
+             ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))
+
+             zn_loc = torch.matmul(ns, acs_loc)
+             #zn_scale = torch.matmul(ns, acs_scale)
+             zn_scale = acs_scale
+
+             if self.latent_dist == 'studentt':
+                 prior_dof = torch.matmul(ns, acs_dof)
+                 zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
+             elif self.latent_dist == 'laplacian':
+                 zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
+             elif self.latent_dist == 'cauchy':
+                 zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
+             elif self.latent_dist == 'normal':
+                 zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
+             elif self.latent_dist == 'gumbel':
+                 zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
+
+             if self.cell_factor_size > 0:
+                 zus = self._perturb_effects(us)
+                 zs = zns + zus
+             else:
+                 zs = zns
+
+             concentrate = self.decoder_concentrate(zs)
+             if self.loss_func in ['bernoulli']:
+                 log_theta = concentrate
+             else:
+                 rate = concentrate.exp()
+                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                 if self.loss_func == 'poisson':
+                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
+
+             if self.loss_func == 'negbinomial':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
+                 else:
+                     pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+             elif self.loss_func == 'poisson':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
+                 else:
+                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
+             elif self.loss_func == 'multinomial':
+                 #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+                 pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             elif self.loss_func == 'bernoulli':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
+                 else:
+                     pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
+
+     def guide2(self, xs, us=None):
+         with pyro.plate('data'):
+             #zn_loc, zn_scale = self.encoder_zn(xs)
+             zn_loc, zn_scale = self._get_basal_embedding(xs)
+             zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
+
+             alpha = self.encoder_n(zns)
+             ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
+
+     def model3(self, xs, ys, embeds=None):
+         pyro.module('PerturbE', self)
+
+         eps = torch.finfo(xs.dtype).eps
+         batch_size = xs.size(0)
+         self.options = dict(dtype=xs.dtype, device=xs.device)
+
+         if self.loss_func == 'negbinomial':
+             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+
+         if self.use_zeroinflate:
+             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
+
+         acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
+
+         I = torch.eye(self.code_size, **self.options)  # keep the identity on the data's device/dtype
+         if self.latent_dist == 'studentt':
+             acs_dof, acs_loc = self.codebook(I)
+         else:
+             acs_loc = self.codebook(I)
+
+         with pyro.plate('data'):
+             #prior = torch.zeros(batch_size, self.code_size, **self.options)
+             prior = self.encoder_n(xs)
+             ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)
+
+             zn_loc = torch.matmul(ns, acs_loc)
+             #prior_scale = torch.matmul(ns, acs_scale)
+             zn_scale = acs_scale
+
+             if self.latent_dist == 'studentt':
+                 prior_dof = torch.matmul(ns, acs_dof)
+                 if embeds is None:
+                     zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
+                 else:
+                     zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
+             elif self.latent_dist == 'laplacian':
+                 if embeds is None:
+                     zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
+                 else:
+                     zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
+             elif self.latent_dist == 'cauchy':
+                 if embeds is None:
+                     zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
+                 else:
+                     zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
+             elif self.latent_dist == 'normal':
+                 if embeds is None:
+                     zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
+                 else:
+                     zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
+             elif self.latent_dist == 'gumbel':
+                 if embeds is None:
+                     zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
+                 else:
+                     zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
+
+             zs = zns
+
+             concentrate = self.decoder_concentrate(zs)
+             if self.loss_func in ['bernoulli']:
+                 log_theta = concentrate
+             else:
+                 rate = concentrate.exp()
+                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                 if self.loss_func == 'poisson':
+                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
+
+             if self.loss_func == 'negbinomial':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
+                 else:
+                     pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+             elif self.loss_func == 'poisson':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
+                 else:
+                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
+             elif self.loss_func == 'multinomial':
+                 #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+                 pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             elif self.loss_func == 'bernoulli':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
+                 else:
+                     pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
+
+     def guide3(self, xs, ys, embeds=None):
+         with pyro.plate('data'):
+             if embeds is None:
+                 #zn_loc, zn_scale = self.encoder_zn(xs)
+                 zn_loc, zn_scale = self._get_basal_embedding(xs)
+                 zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
+             else:
+                 zns = embeds
+
+     def model4(self, xs, us, ys, embeds=None):
+         pyro.module('PerturbE', self)
+
+         eps = torch.finfo(xs.dtype).eps
+         batch_size = xs.size(0)
+         self.options = dict(dtype=xs.dtype, device=xs.device)
+
+         if self.loss_func == 'negbinomial':
+             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+
+         if self.use_zeroinflate:
+             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
+
+         acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
+
+         I = torch.eye(self.code_size, **self.options)  # keep the identity on the data's device/dtype
+         if self.latent_dist == 'studentt':
+             acs_dof, acs_loc = self.codebook(I)
+         else:
+             acs_loc = self.codebook(I)
+
+         with pyro.plate('data'):
+             #prior = torch.zeros(batch_size, self.code_size, **self.options)
+             prior = self.encoder_n(xs)
+             ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)
+
+             zn_loc = torch.matmul(ns, acs_loc)
+             #prior_scale = torch.matmul(ns, acs_scale)
+             zn_scale = acs_scale
+
+             if self.latent_dist == 'studentt':
+                 prior_dof = torch.matmul(ns, acs_dof)
+                 if embeds is None:
+                     zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1))
+                 else:
+                     zns = pyro.sample('zn', dist.StudentT(df=prior_dof, loc=zn_loc, scale=zn_scale).to_event(1), obs=embeds)
+             elif self.latent_dist == 'laplacian':
+                 if embeds is None:
+                     zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1))
+                 else:
+                     zns = pyro.sample('zn', dist.Laplace(zn_loc, zn_scale).to_event(1), obs=embeds)
+             elif self.latent_dist == 'cauchy':
+                 if embeds is None:
+                     zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1))
+                 else:
+                     zns = pyro.sample('zn', dist.Cauchy(zn_loc, zn_scale).to_event(1), obs=embeds)
+             elif self.latent_dist == 'normal':
+                 if embeds is None:
+                     zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
+                 else:
+                     zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1), obs=embeds)
+             elif self.latent_dist == 'gumbel':
+                 if embeds is None:
+                     zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
+                 else:
+                     zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
+
+             if self.cell_factor_size > 0:
+                 #zus = None
+                 #for i in np.arange(self.cell_factor_size):
+                 #    if i == 0:
+                 #        zus = self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
+                 #    else:
+                 #        zus = zus + self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
+                 zus = self._perturb_effects(us)
+                 zs = zns + zus
+             else:
+                 zs = zns
+
+             concentrate = self.decoder_concentrate(zs)
+             if self.loss_func in ['bernoulli']:
+                 log_theta = concentrate
+             else:
+                 rate = concentrate.exp()
+                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                 if self.loss_func == 'poisson':
+                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
+
+             if self.loss_func == 'negbinomial':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta), gate_logits=gate_logits).to_event(1), obs=xs)
+                 else:
+                     pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+             elif self.loss_func == 'poisson':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate), gate_logits=gate_logits).to_event(1), obs=xs.round())
+                 else:
+                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
+             elif self.loss_func == 'multinomial':
+                 #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+                 pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             elif self.loss_func == 'bernoulli':
+                 if self.use_zeroinflate:
+                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta), gate_logits=gate_logits).to_event(1), obs=xs)
+                 else:
+                     pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
+
+     def guide4(self, xs, us, ys, embeds=None):
+         with pyro.plate('data'):
+             if embeds is None:
+                 #zn_loc, zn_scale = self.encoder_zn(xs)
+                 zn_loc, zn_scale = self._get_basal_embedding(xs)
+                 zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
+             else:
+                 zns = embeds
+
+     def _perturb_effects(self, us):
+         zus = self._cell_response(us)
+         return zus
+
+     def _get_codebook_identity(self):
+         return torch.eye(self.code_size, **self.options)
+
+     def _get_codebook(self):
+         I = torch.eye(self.code_size, **self.options)
+         if self.latent_dist == 'studentt':
+             _, cb = self.codebook(I)
+         else:
+             cb = self.codebook(I)
+         return cb
+
+     def get_codebook(self):
+         """
+         Return the mean part of the metacell codebook.
+         """
+         cb = self._get_codebook()
+         cb = tensor_to_numpy(cb)
+         return cb
+
+     def _get_basal_embedding(self, xs):
+         loc, scale = self.encoder_zn(xs)
+         return loc, scale
+
+     def get_basal_embedding(self,
+                             xs,
+                             batch_size: int = 1024):
+         """
+         Return cells' basal latent representations.
+
+         Parameters
+         ----------
+         xs
+             Single-cell expression matrix. It should be a NumPy array or a PyTorch Tensor.
+         batch_size
+             Size of batch processing.
+         """
+         xs = self.preprocess(xs)
+         xs = convert_to_tensor(xs, device=self.get_device())
+         dataset = CustomDataset(xs)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+         Z = []
+         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+             for X_batch, _ in dataloader:
+                 zns, _ = self._get_basal_embedding(X_batch)
+                 Z.append(tensor_to_numpy(zns))
+                 pbar.update(1)
+
+         Z = np.concatenate(Z)
+         return Z
+
+     def _code(self, xs):
+         if self.supervised_mode:
+             alpha = self.encoder_n(xs)
+         else:
+             #zns, _ = self.encoder_zn(xs)
+             zns, _ = self._get_basal_embedding(xs)
+             alpha = self.encoder_n(zns)
+         return alpha
+
+     def code(self, xs, batch_size=1024):
+         xs = self.preprocess(xs)
+         xs = convert_to_tensor(xs, device=self.get_device())
+         dataset = CustomDataset(xs)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+         A = []
+         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+             for X_batch, _ in dataloader:
+                 a = self._code(X_batch)
+                 A.append(tensor_to_numpy(a))
+                 pbar.update(1)
+
+         A = np.concatenate(A)
+         return A
+
+     def _soft_assignments(self, xs):
+         alpha = self._code(xs)
+         alpha = self.softmax(alpha)
+         return alpha
+
+     def soft_assignments(self, xs, batch_size=1024):
+         """
+         Map cells to metacells and return the probabilistic values of metacell assignments
+         """
+         xs = self.preprocess(xs)
+         xs = convert_to_tensor(xs, device=self.get_device())
+         dataset = CustomDataset(xs)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+         A = []
+         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+             for X_batch, _ in dataloader:
+                 a = self._soft_assignments(X_batch)
+                 A.append(tensor_to_numpy(a))
+                 pbar.update(1)
+
+         A = np.concatenate(A)
+         return A
+
+     def _hard_assignments(self, xs):
+         alpha = self._code(xs)
+         res, ind = torch.topk(alpha, 1)
+         ns = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
+         return ns
+
+     def hard_assignments(self, xs, batch_size=1024):
+         """
+         Map cells to metacells and return the assigned metacell identities.
+         """
+         xs = self.preprocess(xs)
+         xs = convert_to_tensor(xs, device=self.get_device())
+         dataset = CustomDataset(xs)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+         A = []
+         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+             for X_batch, _ in dataloader:
+                 a = self._hard_assignments(X_batch)
+                 A.append(tensor_to_numpy(a))
+                 pbar.update(1)
+
+         A = np.concatenate(A)
+         return A
+
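+     # Illustrative note: for a count matrix xs, soft_assignments(xs) yields a
+     # (cells, codebook_size) matrix of assignment probabilities, while
+     # hard_assignments(xs) yields the corresponding one-hot metacell identities.
+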
+     def predict(self, xs, perturbs_us, library_sizes=None):
+         # basal embedding plus the latent shift induced by the given perturbations
+         zs = self.get_basal_embedding(xs)
+         dzs = self.get_cell_response(perturbs_us)
+         zs = zs + dzs
+
+         if library_sizes is None:
+             library_sizes = np.sum(xs, axis=1, keepdims=True)
+         elif type(library_sizes) == list:
+             library_sizes = np.array(library_sizes)
+             library_sizes = library_sizes.reshape(-1, 1)
+         elif len(library_sizes.shape) == 1:
+             library_sizes = library_sizes.reshape(-1, 1)
+
+         counts = self.get_counts(zs, library_sizes=library_sizes)
+
+         return counts, zs
+
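+     # Illustrative sketch of the predict() workflow (shapes are hypothetical):
+     # given raw counts xs of shape (cells, genes) and a perturbation design
+     # perturbs_us of shape (cells, cell_factor_size),
+     #     counts, zs = model.predict(xs, perturbs_us)
+     # returns predicted expression counts under the perturbation together with
+     # the shifted latent coordinates.
+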
+     def _cell_response(self, perturb):
+         ms = self.cell_factor_effect(perturb)
+         return ms
+
+     def get_cell_response(self,
+                           perturb_us,
+                           batch_size: int = 1024):
+         """
+         Return cells' changes in the latent space induced by a specific perturbation of a factor.
+         """
+         #xs = self.preprocess(xs)
+         ps = convert_to_tensor(perturb_us, device=self.get_device())
+         dataset = CustomDataset(ps)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+         Z = []
+         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+             for P_batch, _ in dataloader:
+                 zns = self._cell_response(P_batch)
+                 Z.append(tensor_to_numpy(zns))
+                 pbar.update(1)
+
+         Z = np.concatenate(Z)
+         return Z
+
+     def _get_expression_response(self, delta_zs):
+         return self.decoder_concentrate(delta_zs)
+
+     def get_expression_response(self,
+                                 delta_zs,
+                                 batch_size: int = 1024):
+         """
+         Return cells' changes in the feature space induced by a specific perturbation of a factor.
+         """
+         delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
+         dataset = CustomDataset(delta_zs)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+         R = []
+         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+             for delta_Z_batch, _ in dataloader:
+                 r = self._get_expression_response(delta_Z_batch)
+                 R.append(tensor_to_numpy(r))
+                 pbar.update(1)
+
+         R = np.concatenate(R)
+         return R
+
+     def _count(self, concentrate, library_size=None):
+         if self.loss_func == 'bernoulli':
+             #counts = self.sigmoid(concentrate)
+             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+         elif self.loss_func == 'multinomial':
+             theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+             counts = theta * library_size
+         else:
+             rate = concentrate.exp()
+             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+             counts = theta * library_size
+         return counts
+
+     def get_counts(self, zs, library_sizes,
+                    batch_size: int = 1024):
+
+         zs = convert_to_tensor(zs, device=self.get_device())
+
+         if type(library_sizes) == list:
+             library_sizes = np.array(library_sizes).reshape(-1, 1)
+         elif len(library_sizes.shape) == 1:
+             library_sizes = library_sizes.reshape(-1, 1)
+         ls = convert_to_tensor(library_sizes, device=self.get_device())
+
+         dataset = CustomDataset2(zs, ls)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+         E = []
+         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+             for Z_batch, L_batch, _ in dataloader:
+                 concentrate = self._get_expression_response(Z_batch)
+                 counts = self._count(concentrate, L_batch)
+                 E.append(tensor_to_numpy(counts))
+                 pbar.update(1)
+
+         E = np.concatenate(E)
+         return E
+
+     def preprocess(self, xs, threshold=0):
+         if self.loss_func == 'bernoulli':
+             ad = sc.AnnData(xs)
+             binarize(ad, threshold=threshold)
+             xs = ad.X.copy()
+         else:
+             xs = np.round(xs)
+
+         if sparse.issparse(xs):
+             xs = xs.toarray()
+         return xs
+
+     def fit(self, xs,
+             us = None,
+             ys = None,
+             zs = None,
+             num_epochs: int = 500,
+             learning_rate: float = 0.0001,
+             batch_size: int = 256,
+             algo: Literal['adam','rmsprop','adamw'] = 'adam',
+             beta_1: float = 0.9,
+             weight_decay: float = 0.005,
+             decay_rate: float = 0.9,
+             config_enum: str = 'parallel',
+             threshold: int = 0,
+             use_jax: bool = True):
+         """
+         Train the PerturbE model.
+
+         Parameters
+         ----------
+         xs
+             Single-cell expression matrix. It should be a NumPy array or a PyTorch Tensor. Rows are cells and columns are features.
+         us
+             Cell-level factor matrix.
+         ys
+             Desired factor matrix. It should be a NumPy array or a PyTorch Tensor. Rows are cells and columns are desired factors.
+         zs
+             Optional embedding matrix passed to the supervised models as observed latent values.
+         num_epochs
+             Number of training epochs.
+         learning_rate
+             Parameter for training.
+         batch_size
+             Size of batch processing.
+         algo
+             Optimization algorithm.
+         beta_1
+             Parameter for optimization.
+         weight_decay
+             Parameter for optimization.
+         decay_rate
+             Parameter for optimization.
+         threshold
+             Binarization threshold passed to preprocess() for the bernoulli likelihood.
+         use_jax
+             If toggled on, the JIT-compiled ELBO (JitTraceEnum_ELBO) is used for speed-up. CAUTION: this can raise errors,
+             for unknown reasons, when called from a Python script or Jupyter notebook. It works when PerturbE is run from
+             the shell command line.
+         """
+         xs = self.preprocess(xs, threshold=threshold)
+         xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
+         if us is not None:
+             us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
+         if ys is not None:
+             ys = convert_to_tensor(ys, dtype=self.dtype, device=self.get_device())
+         if zs is not None:
+             zs = convert_to_tensor(zs, dtype=self.dtype, device=self.get_device())
+
+         dataset = CustomDataset4(xs, us, ys, zs)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+         # setup the optimizer
+         optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}
+
+         if algo.lower() == 'rmsprop':
+             optimizer = torch.optim.RMSprop
+         elif algo.lower() == 'adam':
+             optimizer = torch.optim.Adam
+         elif algo.lower() == 'adamw':
+             optimizer = torch.optim.AdamW
+         else:
+             raise ValueError(f"Unsupported optimization algorithm: {algo}")
+         scheduler = ExponentialLR({'optimizer': optimizer, 'optim_args': optim_params, 'gamma': decay_rate})
+
+         pyro.clear_param_store()
+
+         # set up the loss(es) for inference; wrapping the guide in config_enumerate builds the loss as a sum
+         # by enumerating each class label from the sampled discrete categorical distribution in the model
+         Elbo = JitTraceEnum_ELBO if use_jax else TraceEnum_ELBO
+         elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
+         if us is None:
+             if ys is None:
+                 guide = config_enumerate(self.guide1, config_enum, expand=True)
+                 loss_basic = SVI(self.model1, guide, scheduler, loss=elbo)
+             else:
+                 guide = config_enumerate(self.guide3, config_enum, expand=True)
+                 loss_basic = SVI(self.model3, guide, scheduler, loss=elbo)
+         else:
+             if ys is None:
+                 guide = config_enumerate(self.guide2, config_enum, expand=True)
+                 loss_basic = SVI(self.model2, guide, scheduler, loss=elbo)
+             else:
+                 guide = config_enumerate(self.guide4, config_enum, expand=True)
+                 loss_basic = SVI(self.model4, guide, scheduler, loss=elbo)
+
+         # build a list of all losses considered
+         losses = [loss_basic]
+         num_losses = len(losses)
+
+         with tqdm(total=num_epochs, desc='Training', unit='epoch') as pbar:
+             for epoch in range(num_epochs):
+                 epoch_losses = [0.0] * num_losses
+                 for batch_x, batch_u, batch_y, batch_z, _ in dataloader:
+                     if us is None:
+                         batch_u = None
+                     if ys is None:
+                         batch_y = None
+                     if zs is None:
+                         batch_z = None
+
+                     for loss_id in range(num_losses):
+                         if batch_u is None:
+                             if batch_y is None:
+                                 new_loss = losses[loss_id].step(batch_x)
+                             else:
+                                 new_loss = losses[loss_id].step(batch_x, batch_y, batch_z)
+                         else:
+                             if batch_y is None:
+                                 new_loss = losses[loss_id].step(batch_x, batch_u)
+                             else:
+                                 new_loss = losses[loss_id].step(batch_x, batch_u, batch_y, batch_z)
+                         epoch_losses[loss_id] += new_loss
+
+                 avg_epoch_losses_ = map(lambda v: v / len(dataloader), epoch_losses)
+                 avg_epoch_losses = map(lambda v: "{:.4f}".format(v), avg_epoch_losses_)
+
+                 # store the loss
+                 str_loss = " ".join(map(str, avg_epoch_losses))
+
+                 # Update progress bar
+                 pbar.set_postfix({'loss': str_loss})
+                 pbar.update(1)
+
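+     # Dispatch summary for fit(): with neither us nor ys, the unsupervised
+     # model1/guide1 pair is trained; us alone selects model2/guide2
+     # (perturbation effects); ys alone selects model3/guide3 (supervised
+     # codes); providing both selects model4/guide4.
+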
+     @classmethod
+     def save_model(cls, model, file_path, compression=False):
+         """Save the model to the specified file path."""
+         file_path = os.path.abspath(file_path)
+
+         model.eval()
+         if compression:
+             with gzip.open(file_path, 'wb') as pickle_file:
+                 pickle.dump(model, pickle_file)
+         else:
+             with open(file_path, 'wb') as pickle_file:
+                 pickle.dump(model, pickle_file)
+
+         print(f'Model saved to {file_path}')
+
+     @classmethod
+     def load_model(cls, file_path):
+         """Load the model from the specified file path and return an instance."""
+         file_path = os.path.abspath(file_path)
+         if file_path.endswith('gz'):
+             with gzip.open(file_path, 'rb') as pickle_file:
+                 model = pickle.load(pickle_file)
+         else:
+             with open(file_path, 'rb') as pickle_file:
+                 model = pickle.load(pickle_file)
+
+         print(f'Model loaded from {file_path}')
+         return model
+
1068
+
1069
+ EXAMPLE_RUN = (
1070
+ "example run: PerturbE --help"
1071
+ )
1072
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description="PerturbE\n{}".format(EXAMPLE_RUN))
+
+     parser.add_argument(
+         "--cuda", action="store_true", help="use GPU(s) to speed up training"
+     )
+     parser.add_argument(
+         "--jit", action="store_true", help="use PyTorch jit to speed up training"
+     )
+     parser.add_argument(
+         "-n", "--num-epochs", default=200, type=int, help="number of epochs to run"
+     )
+     parser.add_argument(
+         "-enum",
+         "--enum-discrete",
+         default="parallel",
+         help="parallel, sequential or none. uses parallel enumeration by default",
+     )
+     parser.add_argument(
+         "-data",
+         "--data-file",
+         default=None,
+         type=str,
+         help="the data file",
+     )
+     parser.add_argument(
+         "-cf",
+         "--cell-factor-file",
+         default=None,
+         type=str,
+         help="the file for the record of cell-level factors",
+     )
+     parser.add_argument(
+         "-bs",
+         "--batch-size",
+         default=1000,
+         type=int,
+         help="number of cells to be considered in a batch",
+     )
+     parser.add_argument(
+         "-lr",
+         "--learning-rate",
+         default=0.0001,
+         type=float,
+         help="learning rate for Adam optimizer",
+     )
+     parser.add_argument(
+         "-cs",
+         "--codebook-size",
+         default=100,
+         type=int,
+         help="size of vector quantization codebook",
+     )
+     parser.add_argument(
+         "--z-dist",
+         default='gumbel',
+         type=str,
+         choices=['normal','laplacian','studentt','gumbel','cauchy'],
+         help="distribution model for latent representation",
+     )
+     parser.add_argument(
+         "-zd",
+         "--z-dim",
+         default=10,
+         type=int,
+         help="size of the tensor representing the latent variable z",
+     )
+     parser.add_argument(
+         "-likeli",
+         "--likelihood",
+         default='negbinomial',
+         type=str,
+         choices=['negbinomial', 'multinomial', 'poisson', 'bernoulli'],
+         help="specify the distribution likelihood function",
+     )
+     parser.add_argument(
+         "-zi",
+         "--zeroinflate",
+         action="store_true",
+         help="use zero-inflated estimation",
+     )
+     parser.add_argument(
+         "-id",
+         "--inverse-dispersion",
+         default=10.0,
+         type=float,
+         help="inverse dispersion prior for negative binomial",
+     )
+     parser.add_argument(
+         "-hl",
+         "--hidden-layers",
+         nargs="+",
+         default=[500],
+         type=int,
+         help="a tuple (or list) of MLP layers to be used in the neural networks "
+         "representing the parameters of the distributions in our model",
+     )
+     parser.add_argument(
+         "-hla",
+         "--hidden-layer-activation",
+         default='relu',
+         type=str,
+         choices=['relu','softplus','leakyrelu','linear'],
+         help="activation function for hidden layers",
+     )
+     parser.add_argument(
+         "-plf",
+         "--post-layer-function",
+         nargs="+",
+         default=['layernorm'],
+         type=str,
+         help="post functions for hidden layers; can be none, dropout, layernorm, batchnorm, or a combination; default is 'layernorm'",
+     )
+     parser.add_argument(
+         "-paf",
+         "--post-activation-function",
+         nargs="+",
+         default=['none'],
+         type=str,
+         help="post functions for activation layers; can be none or dropout; default is 'none'",
+     )
+     parser.add_argument(
+         "-64",
+         "--float64",
+         action="store_true",
+         help="use double float precision",
+     )
+     parser.add_argument(
+         "-dr",
+         "--decay-rate",
+         default=0.9,
+         type=float,
+         help="decay rate for Adam optimizer",
+     )
+     parser.add_argument(
+         "--layer-dropout-rate",
+         default=0.1,
+         type=float,
+         help="dropout rate for neural networks",
+     )
+     parser.add_argument(
+         "-b1",
+         "--beta-1",
+         default=0.95,
+         type=float,
+         help="beta-1 parameter for Adam optimizer",
+     )
+     parser.add_argument(
+         "--seed",
+         default=None,
+         type=int,
+         help="seed for controlling randomness in this example",
+     )
+     parser.add_argument(
+         "--save-model",
+         default=None,
+         type=str,
+         help="path to save model for prediction",
+     )
+     args = parser.parse_args()
+     return args
+
+ def main():
+     args = parse_args()
+     assert (
+         (args.data_file is not None) and (
+             os.path.exists(args.data_file))
+     ), "a valid data file must be provided"
+
+     if args.seed is not None:
+         set_random_seed(args.seed)
+
+     if args.float64:
+         dtype = torch.float64
+         torch.set_default_dtype(torch.float64)
+     else:
+         dtype = torch.float32
+         torch.set_default_dtype(torch.float32)
+
+     xs = dt.fread(file=args.data_file, header=True).to_numpy()
+     us = None
+     if args.cell_factor_file is not None:
+         us = dt.fread(file=args.cell_factor_file, header=True).to_numpy()
+
+     input_size = xs.shape[1]
+     cell_factor_size = 0 if us is None else us.shape[1]
+
+     ###########################################
+     perturbe = PerturbE(
+         input_size=input_size,
+         cell_factor_size=cell_factor_size,
+         inverse_dispersion=args.inverse_dispersion,
+         z_dim=args.z_dim,
+         hidden_layers=args.hidden_layers,
+         hidden_layer_activation=args.hidden_layer_activation,
+         use_cuda=args.cuda,
+         config_enum=args.enum_discrete,
+         use_zeroinflate=args.zeroinflate,
+         loss_func=args.likelihood,
+         nn_dropout=args.layer_dropout_rate,
+         post_layer_fct=args.post_layer_function,
+         post_act_fct=args.post_activation_function,
+         codebook_size=args.codebook_size,
+         z_dist=args.z_dist,
+         dtype=dtype,
+     )
+
+     perturbe.fit(xs, us=us,
+                  num_epochs=args.num_epochs,
+                  learning_rate=args.learning_rate,
+                  batch_size=args.batch_size,
+                  beta_1=args.beta_1,
+                  decay_rate=args.decay_rate,
+                  use_jax=args.jit,
+                  config_enum=args.enum_discrete,
+                  )
+
+     if args.save_model is not None:
+         if args.save_model.endswith('gz'):
+             PerturbE.save_model(perturbe, args.save_model, compression=True)
+         else:
+             PerturbE.save_model(perturbe, args.save_model)
+
+
+ if __name__ == "__main__":
+     main()
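
Usage sketch: a minimal example of how the PerturbE API added above might be driven from Python. The synthetic data, shapes, and output path below are illustrative assumptions; the constructor arguments and the fit / get_basal_embedding / get_cell_response / get_counts / save_model calls mirror the code in this diff.

import numpy as np
from SURE.PerturbE import PerturbE

# Hypothetical data: 1000 cells x 2000 genes, one-hot design over 4 perturbations
xs = np.random.poisson(1.0, size=(1000, 2000)).astype('float32')
us = np.eye(4, dtype='float32')[np.random.randint(0, 4, size=1000)]

model = PerturbE(
    input_size=xs.shape[1],
    cell_factor_size=us.shape[1],
    codebook_size=100,
    z_dim=10,
    loss_func='multinomial',
    use_cuda=False,
)

# With us given and ys absent, fit() trains the model2/guide2 pair
model.fit(xs, us=us, num_epochs=10, batch_size=256, use_jax=False)

zs = model.get_basal_embedding(xs)   # basal latent coordinates
dzs = model.get_cell_response(us)    # perturbation-induced latent shifts
counts = model.get_counts(zs + dzs, library_sizes=xs.sum(axis=1))

PerturbE.save_model(model, 'perturbe_model.pkl')  # hypothetical output path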