SURE-tools 2.4.25.tar.gz → 2.4.38.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of SURE-tools has been flagged as potentially problematic.
- {sure_tools-2.4.25 → sure_tools-2.4.38}/PKG-INFO +1 -1
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/DensityFlow.py +149 -69
- sure_tools-2.4.25/SURE/DensityFlow2.py → sure_tools-2.4.38/SURE/DensityFlowLinear.py +91 -99
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/__init__.py +3 -3
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/SOURCES.txt +1 -1
- {sure_tools-2.4.25 → sure_tools-2.4.38}/setup.py +1 -1
- {sure_tools-2.4.25 → sure_tools-2.4.38}/LICENSE +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/README.md +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/EfficientTranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/PerturbE.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/PerturbationAwareDecoder.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/SURE.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/SimpleTranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/TranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/VirtualCellDecoder.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/atac/utils.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/utils/queue.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/utils/utils.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.38}/setup.cfg +0 -0
{sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/DensityFlow.py

@@ -57,16 +57,16 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int =
+                 codebook_size: int = 100,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
-                 z_dim: int =
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
-
+                 z_dim: int = 50,
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
-                 hidden_layers: list = [
+                 hidden_layers: list = [1024],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
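For orientation, a minimal construction sketch using the new 2.4.38 defaults shown above; the input_size value is hypothetical, everything else mirrors the diff:

import numpy as np
from SURE import DensityFlow

model = DensityFlow(
    input_size=2000,          # hypothetical gene count
    codebook_size=100,        # new default
    z_dim=50,                 # new default
    z_dist='studentt',        # new default latent distribution
    loss_func='negbinomial',  # new default likelihood
    dispersion=8.0,           # new argument: initial NB dispersion
    hidden_layers=[1024],     # new default width
)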
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
 
         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
+        self.config_enum = config_enum
         self.allow_broadcast = config_enum == 'parallel'
         self.use_cuda = use_cuda
         self.loss_func = loss_func
@@ -107,6 +108,7 @@ class DensityFlow(nn.Module):
 
         self.codebook_weights = None
 
+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()
 
@@ -258,7 +260,7 @@ class DensityFlow(nn.Module):
             )
         )
 
-        self.
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -348,8 +350,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
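The new dispersion handling replaces the previous per-gene treatment with a single learnable pyro.param. A standalone sketch of what this registers (input_size is hypothetical; pyro.param and constraints are the standard Pyro/torch APIs):

import pyro
import torch
from torch.distributions import constraints

input_size = 2000                            # hypothetical gene count
init = 8.0 * torch.ones(input_size)          # self.dispersion * xs.new_ones(...)
# One positive value per gene, stored in Pyro's param store and
# optimized by SVI alongside the network weights.
dispersion = pyro.param("dispersion", init, constraint=constraints.positive)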
@@ -383,28 +385,32 @@ class DensityFlow(nn.Module):
             zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
 
             zs = zns
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
+                                logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
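A short sanity check on the new negative-binomial parameterization: torch's NegativeBinomial(total_count=r, logits=l) has mean r * exp(l), so setting logits = log(mu) - log(dispersion) makes the likelihood's mean exactly the decoded mu; the clamp to [-15, 15] only guards exp() against overflow. A sketch with hypothetical values:

import torch
from torch.distributions import NegativeBinomial

mu = torch.tensor([0.5, 4.0, 120.0])          # hypothetical per-gene means
r = torch.full((3,), 8.0)                     # dispersion (total_count)
logits = (mu.log() - r.log()).clamp(-15, 15)
nb = NegativeBinomial(total_count=r, logits=logits)
print(nb.mean)                                # ~tensor([0.5, 4.0, 120.0])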
@@ -428,8 +434,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -468,28 +474,30 @@ class DensityFlow(nn.Module):
             else:
                 zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -513,8 +521,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -565,28 +573,31 @@ class DensityFlow(nn.Module):
 
             zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -610,8 +621,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -672,28 +683,31 @@ class DensityFlow(nn.Module):
             else:
                 zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -717,13 +731,13 @@ class DensityFlow(nn.Module):
                 #    zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = self.
+                zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
             else:
                 #if self.turn_off_cell_specific:
                 #    zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = zus + self.
+                zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
         return zus
 
     def _get_codebook_identity(self):
@@ -744,6 +758,28 @@ class DensityFlow(nn.Module):
         cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb
+
+    def _get_complete_embedding(self, xs, us):
+        basal,_ = self._get_basal_embedding(xs)
+        dzs = self._total_effects(basal, us)
+        return basal + dzs
+
+    def get_complete_embedding(self, xs, us, batch_size:int=1024):
+        xs = self.preprocess(xs)
+        xs = convert_to_tensor(xs, device=self.get_device())
+        us = convert_to_tensor(us, device=self.get_device())
+        dataset = CustomDataset2(xs, us)
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+        Z = []
+        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+            for X_batch, U_batch, _ in dataloader:
+                zns = self._get_basal_embedding(X_batch, U_batch)
+                Z.append(tensor_to_numpy(zns))
+                pbar.update(1)
+
+        Z = np.concatenate(Z)
+        return Z
 
     def _get_basal_embedding(self, xs):
         loc, scale = self.encoder_zn(xs)
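A usage sketch of the new API, reusing the model object from the construction sketch earlier; the array shapes are assumptions (xs: cells x genes counts, us: cells x cell factors). Note that the batched loop above calls _get_basal_embedding with two arguments rather than the _get_complete_embedding helper it defines just before, so the batched and helper paths may not be equivalent.

import numpy as np

xs = np.random.poisson(1.0, size=(256, 2000)).astype(np.float32)   # hypothetical counts
us = np.random.binomial(1, 0.5, size=(256, 3)).astype(np.float32)  # hypothetical factors

Z = model.get_complete_embedding(xs, us)   # basal embedding plus factor shifts
print(Z.shape)                             # (256, z_dim)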
@@ -865,12 +901,12 @@ class DensityFlow(nn.Module):
             us_i = us[:,pert_idx].reshape(-1,1)
 
             # factor effect of xs
-            dzs0 = self.
+            dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
 
             # perturbation effect
             ps = np.ones_like(us_i)
             if np.sum(np.abs(ps-us_i))>=1:
-                dzs = self.
+                dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
                 zs = zs + dzs0 + dzs
             else:
                 zs = zs + dzs0
@@ -884,10 +920,11 @@ class DensityFlow(nn.Module):
             library_sizes = library_sizes.reshape(-1,1)
 
         counts = self.get_counts(zs, library_sizes=library_sizes)
+        log_mu = self.get_log_mu(zs)
 
-        return counts,
+        return counts, log_mu
 
-    def
+    def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
         #zns,_ = self._get_basal_embedding(xs)
         zns = zs
@@ -904,7 +941,7 @@ class DensityFlow(nn.Module):
 
         return ms
 
-    def
+    def get_cell_shift(self,
                        zs,
                        perturb_idx,
                        perturb_us,
@@ -922,46 +959,43 @@ class DensityFlow(nn.Module):
         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, P_batch, _ in dataloader:
-                zns = self.
+                zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)
 
         Z = np.concatenate(Z)
         return Z
 
-    def
-        return self.
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)
 
-    def
-        delta_zs,
-        batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor
 
         """
-
-        dataset = CustomDataset(
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                r = self.
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)
 
         R = np.concatenate(R)
         return R
 
-    def _count(self,
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
-
-            counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts
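The refactor routes all decoding through log_mu. For the count likelihoods, _count() turns exp(log_mu) into simplex proportions via the DirichletMultinomial mean (with total_count=1 this is just the concentration normalized to sum to one) and scales by library size. An equivalent standalone sketch with hypothetical values:

import torch

log_mu = torch.randn(4, 2000)                  # hypothetical decoder output
library_size = torch.full((4, 1), 10_000.0)    # per-cell totals
rate = log_mu.exp()
theta = rate / rate.sum(dim=1, keepdim=True)   # == DirichletMultinomial(1, rate).mean
counts = theta * library_size                  # expected counts per cell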
@@ -983,8 +1017,8 @@ class DensityFlow(nn.Module):
         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
-
-                counts = self._count(
+                log_mu = self._log_mu(Z_batch)
+                counts = self._count(log_mu, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)
 
@@ -1157,8 +1191,54 @@ class DensityFlow(nn.Module):
         else:
             with open(file_path, 'rb') as pickle_file:
                 model = pickle.load(pickle_file)
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f"   - Latent Dimension: {model.latent_dim}")
+        print(f"   - Gene Dimension: {model.input_size}")
+        print(f"   - Hidden Dimensions: {model.hidden_layers}")
+        print(f"   - Device: {model.get_device()}")
+        print(f"   - Parameters: {sum(p.numel() for p in model.parameters()):,}")
 
         return model
+
+    ''' def save(self, path):
+        """Save model checkpoint"""
+        torch.save({
+            'model_state_dict': self.state_dict(),
+            'model_config': {
+                'input_size': self.input_size,
+                'codebook_size': self.code_size,
+                'cell_factor_size': self.cell_factor_size,
+                'turn_off_cell_specific':self.turn_off_cell_specific,
+                'supervised_mode':self.supervised_mode,
+                'z_dim': self.latent_dim,
+                'z_dist': self.latent_dist,
+                'loss_func': self.loss_func,
+                'dispersion': self.dispersion,
+                'use_zeroinflate': self.use_zeroinflate,
+                'hidden_layers':self.hidden_layers,
+                'hidden_layer_activation':self.hidden_layer_activation,
+                'nn_dropout':self.nn_dropout,
+                'post_layer_fct':self.post_layer_fct,
+                'post_act_fct':self.post_act_fct,
+                'config_enum':self.config_enum,
+                'use_cuda':self.use_cuda,
+                'seed':self.seed,
+                'zero_bias':self.use_bias,
+                'dtype':self.dtype,
+            }
+        }, path)
+
+    @classmethod
+    def load_model(cls, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path)
+        model = DensityFlow(**checkpoint.get('model_config'))
+
+        checkpoint = torch.load(model_path, map_location=model.get_device())
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+        return model'''
 
 
 EXAMPLE_RUN = (
@@ -1357,7 +1437,7 @@ def main():
     df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-
+        dispersion=args.dispersion,
        z_dim=args.z_dim,
        hidden_layers=args.hidden_layers,
        hidden_layer_activation=args.hidden_layer_activation,
sure_tools-2.4.25/SURE/DensityFlow2.py → sure_tools-2.4.38/SURE/DensityFlowLinear.py

@@ -54,7 +54,7 @@ def set_random_seed(seed):
     # Set seed for Pyro
     pyro.set_rng_seed(seed)
 
-class
+class DensityFlowLinear(nn.Module):
     def __init__(self,
                  input_size: int,
                  codebook_size: int = 200,
@@ -63,8 +63,8 @@ class DensityFlow2(nn.Module):
                  supervised_mode: bool = False,
                  z_dim: int = 10,
                  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
-
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
                  hidden_layers: list = [500],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
@@ -81,7 +81,7 @@ class DensityFlow2(nn.Module):
 
         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
@@ -107,10 +107,11 @@ class DensityFlow2(nn.Module):
 
         self.codebook_weights = None
 
+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()
 
-        print(f"🧬
+        print(f"🧬 DensityFlowLinear Initialized:")
         print(f"   - Latent Dimension: {self.latent_dim}")
         print(f"   - Gene Dimension: {self.input_size}")
         print(f"   - Hidden Dimensions: {self.hidden_layers}")
@@ -208,17 +209,6 @@ class DensityFlow2(nn.Module):
             use_cuda=self.use_cuda,
         )
 
-        if self.loss_func == 'negbinomial':
-            self.encoder_inverse_dispersion = MLP(
-                [self.latent_dim] + hidden_sizes + [[self.input_size, self.input_size]],
-                activation=activate_fct,
-                output_activation=[Exp, Exp],
-                post_layer_fct=post_layer_fct,
-                post_act_fct=post_act_fct,
-                allow_broadcast=self.allow_broadcast,
-                use_cuda=self.use_cuda,
-            )
-
         if self.cell_factor_size>0:
             self.cell_factor_effect = nn.ModuleList()
             for i in np.arange(self.cell_factor_size):
@@ -269,7 +259,7 @@ class DensityFlow2(nn.Module):
             )
         )
 
-        self.
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -352,15 +342,15 @@ class DensityFlow2(nn.Module):
         return xs
 
     def model1(self, xs):
-        pyro.module('
+        pyro.module('DensityFlowLinear', self)
 
         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -394,28 +384,31 @@ class DensityFlow2(nn.Module):
             zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
 
             zs = zns
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -432,17 +425,15 @@ class DensityFlow2(nn.Module):
             ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
 
     def model2(self, xs, us=None):
-        pyro.module('
+        pyro.module('DensityFlowLinear', self)
 
         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
-            with pyro.plate("genes", self.input_size):
-                inverse_dispersion = pyro.sample("inverse_dispersion", dist.LogNormal(self.inverse_dispersion, 0.5).to_event(1))
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -482,28 +473,28 @@ class DensityFlow2(nn.Module):
                zs = zns'''
 
             zs = zns
-
+            log_mu = self.decoder_log_mu(zs)
             for i in np.arange(self.cell_factor_size):
                 zus = self._cell_shift(zs, i, us[:,i].reshape(-1,1))
-
+                log_mu += self.decoder_log_mu(zus)
 
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
             elif self.loss_func in ['negbinomial']:
-                mu =
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
-                logits = (mu.log()-
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
                                                                    logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
                                 logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
@@ -511,8 +502,7 @@ class DensityFlow2(nn.Module):
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -528,21 +518,16 @@ class DensityFlow2(nn.Module):
             alpha = self.encoder_n(zns)
             ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
 
-            if self.loss_func == 'negbinomial':
-                id_loc,id_scale = self.encoder_inverse_dispersion(zns)
-                with pyro.plate("genes", self.input_size):
-                    pyro.sample("inverse_dispersion", dist.LogNormal(id_loc, id_scale).to_event(1))
-
     def model3(self, xs, ys, embeds=None):
-        pyro.module('
+        pyro.module('DensityFlowLinear', self)
 
         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -593,28 +578,31 @@ class DensityFlow2(nn.Module):
 
             zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -631,15 +619,15 @@ class DensityFlow2(nn.Module):
             zns = embeds
 
     def model4(self, xs, us, ys, embeds=None):
-        pyro.module('
+        pyro.module('DensityFlowLinear', self)
 
         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -695,32 +683,34 @@ class DensityFlow2(nn.Module):
                zs = zns'''
 
             zs = zns
-
+            log_mu = self.decoder_log_mu(zs)
             for i in np.arange(self.cell_factor_size):
                 zus = self._cell_shift(zs, i, us[:,i].reshape(-1,1))
-
+                log_mu += self.decoder_log_mu(zus)
 
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
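The apparent point of the rename: in this variant the factor effects are composed additively in log-mean space, log_mu = f(z) + sum_i f(dz_i), so each factor acts as a multiplicative scaling of expected expression. A toy sketch of that identity; f here is a hypothetical stand-in for decoder_log_mu:

import torch

W = torch.randn(10, 2000)          # hypothetical linear decoder weights

def f(z):                          # stand-in for self.decoder_log_mu
    return z @ W

z = torch.randn(8, 10)                         # basal embedding
dz = [torch.randn(8, 10) for _ in range(3)]    # per-factor shifts

log_mu = f(z) + sum(f(d) for d in dz)
mu = log_mu.exp()                  # equals exp(f(z)) * prod_i exp(f(dz_i))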
@@ -887,6 +877,8 @@ class DensityFlow2(nn.Module):
 
         # basal embedding
         zs = self.get_basal_embedding(xs)
+        log_mu = self.get_log_mu(zs)
+
         for pert in perturbs_predict:
             pert_idx = int(np.where(perturbs_reference==pert)[0])
             us_i = us[:,pert_idx].reshape(-1,1)
@@ -898,10 +890,12 @@ class DensityFlow2(nn.Module):
             ps = np.ones_like(us_i)
             if np.sum(np.abs(ps-us_i))>=1:
                 dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
-
+                delta = dzs0 + dzs
             else:
-
-
+                delta = dzs0
+
+            log_mu = log_mu + self.get_log_mu(delta)
+
         if library_sizes is None:
             library_sizes = np.sum(xs, axis=1, keepdims=True)
         elif type(library_sizes) == list:
@@ -910,9 +904,9 @@ class DensityFlow2(nn.Module):
         elif len(library_sizes.shape)==1:
             library_sizes = library_sizes.reshape(-1,1)
 
-        counts = self.get_counts(
+        counts = self.get_counts(log_mu, library_sizes=library_sizes)
 
-        return counts,
+        return counts, log_mu
 
     def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
@@ -956,48 +950,46 @@ class DensityFlow2(nn.Module):
         Z = np.concatenate(Z)
         return Z
 
-    def
-        return self.
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)
 
-    def
-        delta_zs,
-        batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor
 
         """
-
-        dataset = CustomDataset(
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                r = self.
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)
 
         R = np.concatenate(R)
         return R
 
-    def _count(self,
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
-            counts = dist.Bernoulli(logits=
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts
 
-    def get_counts(self,
+    def get_counts(self, log_mu,
                    library_sizes,
                    batch_size: int = 1024):
 
-
+        log_mu = convert_to_tensor(log_mu, device=self.get_device())
 
         if type(library_sizes) == list:
             library_sizes = np.array(library_sizes).reshape(-1,1)
@@ -1005,13 +997,13 @@ class DensityFlow2(nn.Module):
             library_sizes = library_sizes.reshape(-1,1)
         ls = convert_to_tensor(library_sizes, device=self.get_device())
 
-        dataset = CustomDataset2(
+        dataset = CustomDataset2(log_mu,ls)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                counts = self._count(
+            for Mu_batch, L_batch, _ in dataloader:
+                counts = self._count(Mu_batch, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)
 
@@ -1045,7 +1037,7 @@ class DensityFlow2(nn.Module):
             threshold: int = 0,
             use_jax: bool = True):
        """
-        Train the
+        Train the DensityFlowLinear model.
 
        Parameters
        ----------
@@ -1071,7 +1063,7 @@ class DensityFlow2(nn.Module):
            Parameter for optimization.
        use_jax
            If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
-           the Python script or Jupyter notebook. It is OK if it is used when runing
+           the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlowLinear in the shell command.
        """
        xs = self.preprocess(xs, threshold=threshold)
        xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1189,12 +1181,12 @@ class DensityFlow2(nn.Module):
 
 
 EXAMPLE_RUN = (
-    "example run:
+    "example run: DensityFlowLinear --help"
 )
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description="
+        description="DensityFlowLinear\n{}".format(EXAMPLE_RUN))
 
     parser.add_argument(
         "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1381,10 +1373,10 @@ def main():
     cell_factor_size = 0 if us is None else us.shape[1]
 
     ###########################################
-    df =
+    df = DensityFlowLinear(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-
+        dispersion=args.dispersion,
         z_dim=args.z_dim,
         hidden_layers=args.hidden_layers,
         hidden_layer_activation=args.hidden_layer_activation,
@@ -1412,9 +1404,9 @@ def main():
 
     if args.save_model is not None:
         if args.save_model.endswith('gz'):
-
+            DensityFlowLinear.save_model(df, args.save_model, compression=True)
         else:
-
+            DensityFlowLinear.save_model(df, args.save_model)
 
 
 
{sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/__init__.py

@@ -1,6 +1,6 @@
 from .SURE import SURE
 from .DensityFlow import DensityFlow
-from .
+from .DensityFlowLinear import DensityFlowLinear
 from .PerturbE import PerturbE
 from .TranscriptomeDecoder import TranscriptomeDecoder
 from .SimpleTranscriptomeDecoder import SimpleTranscriptomeDecoder
@@ -12,7 +12,7 @@ from . import utils
 from . import codebook
 from . import SURE
 from . import DensityFlow
-from . import
+from . import DensityFlowLinear
 from . import atac
 from . import flow
 from . import perturb
@@ -23,6 +23,6 @@ from . import EfficientTranscriptomeDecoder
 from . import VirtualCellDecoder
 from . import PerturbationAwareDecoder
 
-__all__ = ['SURE', 'DensityFlow', '
+__all__ = ['SURE', 'DensityFlow', 'DensityFlowLinear', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
            'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'PerturbationAwareDecoder',
            'flow', 'perturb', 'atac', 'utils', 'codebook']
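After the rename, both flow variants are importable from the package root, matching the updated __init__.py above:

from SURE import DensityFlow, DensityFlowLinear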