SURE-tools 2.4.25__tar.gz → 2.4.42__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of SURE-tools might be problematic.
- {sure_tools-2.4.25 → sure_tools-2.4.42}/PKG-INFO +1 -1
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/DensityFlow.py +152 -70
- sure_tools-2.4.25/SURE/DensityFlow2.py → sure_tools-2.4.42/SURE/DensityFlowLinear.py +91 -99
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/__init__.py +3 -3
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/SOURCES.txt +1 -1
- {sure_tools-2.4.25 → sure_tools-2.4.42}/setup.py +1 -1
- {sure_tools-2.4.25 → sure_tools-2.4.42}/LICENSE +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/README.md +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/EfficientTranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/PerturbE.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/PerturbationAwareDecoder.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/SURE.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/SimpleTranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/TranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/VirtualCellDecoder.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/atac/utils.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/utils/queue.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/utils/utils.py +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.4.25 → sure_tools-2.4.42}/setup.cfg +0 -0
SURE/DensityFlow.py

@@ -57,16 +57,16 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int =
+                 codebook_size: int = 30,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
-                 z_dim: int =
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
-
+                 z_dim: int = 50,
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
-                 hidden_layers: list = [
+                 hidden_layers: list = [1024],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):

         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
+        self.config_enum = config_enum
         self.allow_broadcast = config_enum == 'parallel'
         self.use_cuda = use_cuda
         self.loss_func = loss_func
@@ -107,10 +108,12 @@ class DensityFlow(nn.Module):

         self.codebook_weights = None

+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()

         print(f"🧬 DensityFlow Initialized:")
+        print(f"   - Codebook size: {self.code_size}")
         print(f"   - Latent Dimension: {self.latent_dim}")
         print(f"   - Gene Dimension: {self.input_size}")
         print(f"   - Hidden Dimensions: {self.hidden_layers}")
@@ -258,7 +261,7 @@ class DensityFlow(nn.Module):
             )
         )

-        self.
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -348,8 +351,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
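Note (reviewer's sketch, not package code): the hunks above and below replace the old dispersion handling with a single learnable, positivity-constrained Pyro parameter shared by every model variant. A minimal self-contained sketch of that pattern, assuming a (cells, genes) count tensor; names mirror the diff but the function itself is illustrative:

```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def nb_likelihood(xs, log_mu, init_dispersion=8.0):
    # One positive dispersion value per gene, optimized directly by SVI.
    dispersion = pyro.param(
        "dispersion",
        init_dispersion * xs.new_ones(xs.shape[1]),
        constraint=constraints.positive,
    )
    # Same parameterization as the diff: total_count = dispersion,
    # logits = log(mu) - log(dispersion), clamped for numerical stability.
    logits = (log_mu - dispersion.log()).clamp(min=-10, max=10)
    pyro.sample(
        "x",
        dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1),
        obs=xs,
    )

xs = torch.poisson(torch.full((4, 6), 5.0))   # toy count matrix
nb_likelihood(xs, log_mu=torch.randn(4, 6))
```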
@@ -383,28 +386,32 @@ class DensityFlow(nn.Module):
             zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

         zs = zns
-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func == 'negbinomial':
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                            logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
+                            logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
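Why this parameterization recovers the decoder's mean (an arithmetic check, not package code): for a NegativeBinomial with total_count = r and logits = log(mu) − log(r), torch's mean is total_count · exp(logits) = r · mu / r = mu, and the variance is mu + mu²/r. So, up to the ±10 clamp on the logits, `dispersion` only controls overdispersion while `log_mu.exp()` remains the predicted mean:

```python
import torch
from torch.distributions import NegativeBinomial

mu = torch.tensor([0.5, 3.0, 40.0])
r = torch.tensor([8.0, 8.0, 8.0])  # dispersion, used as total_count
nb = NegativeBinomial(total_count=r, logits=mu.log() - r.log())
print(nb.mean)                          # ~ [0.5, 3.0, 40.0], i.e. mu
print(nb.variance - (mu + mu**2 / r))   # ~ zeros: Poisson + quadratic term
```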
@@ -428,8 +435,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -468,28 +475,30 @@ class DensityFlow(nn.Module):
         else:
             zs = zns

-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func == 'negbinomial':
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -513,8 +522,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -565,28 +574,31 @@ class DensityFlow(nn.Module):

         zs = zns

-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                            logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -610,8 +622,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -672,28 +684,31 @@ class DensityFlow(nn.Module):
         else:
             zs = zns

-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                            logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -717,13 +732,13 @@ class DensityFlow(nn.Module):
                 #    zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = self.
+                zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
             else:
                 #if self.turn_off_cell_specific:
                 #    zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = zus + self.
+                zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
         return zus

     def _get_codebook_identity(self):
@@ -744,6 +759,28 @@ class DensityFlow(nn.Module):
         cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb
+
+    def _get_complete_embedding(self, xs, us):
+        basal,_ = self._get_basal_embedding(xs)
+        dzs = self._total_effects(basal, us)
+        return basal + dzs
+
+    def get_complete_embedding(self, xs, us, batch_size:int=1024):
+        xs = self.preprocess(xs)
+        xs = convert_to_tensor(xs, device=self.get_device())
+        us = convert_to_tensor(us, device=self.get_device())
+        dataset = CustomDataset2(xs, us)
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+        Z = []
+        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+            for X_batch, U_batch, _ in dataloader:
+                zns = self._get_complete_embedding(X_batch, U_batch)
+                Z.append(tensor_to_numpy(zns))
+                pbar.update(1)
+
+        Z = np.concatenate(Z)
+        return Z

     def _get_basal_embedding(self, xs):
         loc, scale = self.encoder_zn(xs)
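A hypothetical usage sketch for the new `get_complete_embedding` API (the method signature follows the hunk above; the constructor arguments and data below are invented for illustration):

```python
import numpy as np
from SURE import DensityFlow

xs = np.random.poisson(2.0, size=(200, 1000)).astype(float)   # toy counts
us = np.random.binomial(1, 0.3, size=(200, 3)).astype(float)  # cell factors

model = DensityFlow(input_size=1000, cell_factor_size=3)
# Basal embedding plus the summed per-factor shifts, computed in batches.
z = model.get_complete_embedding(xs, us, batch_size=512)
print(z.shape)  # (200, z_dim)
```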
@@ -865,12 +902,12 @@ class DensityFlow(nn.Module):
             us_i = us[:,pert_idx].reshape(-1,1)

             # factor effect of xs
-            dzs0 = self.
+            dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)

             # perturbation effect
             ps = np.ones_like(us_i)
             if np.sum(np.abs(ps-us_i))>=1:
-                dzs = self.
+                dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
                 zs = zs + dzs0 + dzs
             else:
                 zs = zs + dzs0
@@ -884,10 +921,11 @@ class DensityFlow(nn.Module):
             library_sizes = library_sizes.reshape(-1,1)

         counts = self.get_counts(zs, library_sizes=library_sizes)
+        log_mu = self.get_log_mu(zs)

-        return counts,
+        return counts, log_mu

-    def
+    def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
         #zns,_ = self._get_basal_embedding(xs)
         zns = zs
@@ -904,7 +942,7 @@ class DensityFlow(nn.Module):

         return ms

-    def
+    def get_cell_shift(self,
                        zs,
                        perturb_idx,
                        perturb_us,
@@ -922,46 +960,43 @@ class DensityFlow(nn.Module):
         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, P_batch, _ in dataloader:
-                zns = self.
+                zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)

         Z = np.concatenate(Z)
         return Z

-    def
-        return self.
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)

-    def
-                   delta_zs,
-                   batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor

         """
-
-        dataset = CustomDataset(
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                r = self.
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)

         R = np.concatenate(R)
         return R

-    def _count(self,
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
-
-            counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts
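The refactor above funnels every decode through one quantity, `log_mu`. A standalone restatement of `_count`'s core logic (a sketch, not an import of the package internals; `softmax(log_mu)` equals `DirichletMultinomial(1, exp(log_mu)).mean`, and the diff's multinomial branch additionally carries a `1e8` total_count factor that is omitted here):

```python
import torch

def counts_from_log_mu(log_mu, library_size, loss_func="negbinomial"):
    """Expected counts per gene, per cell, given decoder output log_mu."""
    if loss_func == "bernoulli":
        return torch.sigmoid(log_mu)      # detection probability per gene
    # Other losses: normalize exp(log_mu) into per-gene proportions,
    # then scale by each cell's library size.
    theta = torch.softmax(log_mu, dim=-1)
    return theta * library_size

log_mu = torch.randn(4, 10)
lib = torch.full((4, 1), 1e4)
print(counts_from_log_mu(log_mu, lib).sum(dim=-1))  # ~ the library sizes
```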
@@ -983,8 +1018,8 @@ class DensityFlow(nn.Module):
         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
-
-                counts = self._count(
+                log_mu = self._log_mu(Z_batch)
+                counts = self._count(log_mu, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)

@@ -996,7 +1031,7 @@ class DensityFlow(nn.Module):
             ad = sc.AnnData(xs)
             binarize(ad, threshold=threshold)
             xs = ad.X.copy()
-
+        elif self.loss_func == 'poisson':
             xs = np.round(xs)

         if sparse.issparse(xs):
@@ -1157,8 +1192,55 @@ class DensityFlow(nn.Module):
         else:
             with open(file_path, 'rb') as pickle_file:
                 model = pickle.load(pickle_file)
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f"   - Codebook size: {model.code_size}")
+        print(f"   - Latent Dimension: {model.latent_dim}")
+        print(f"   - Gene Dimension: {model.input_size}")
+        print(f"   - Hidden Dimensions: {model.hidden_layers}")
+        print(f"   - Device: {model.get_device()}")
+        print(f"   - Parameters: {sum(p.numel() for p in model.parameters()):,}")

         return model
+
+    ''' def save(self, path):
+        """Save model checkpoint"""
+        torch.save({
+            'model_state_dict': self.state_dict(),
+            'model_config': {
+                'input_size': self.input_size,
+                'codebook_size': self.code_size,
+                'cell_factor_size': self.cell_factor_size,
+                'turn_off_cell_specific':self.turn_off_cell_specific,
+                'supervised_mode':self.supervised_mode,
+                'z_dim': self.latent_dim,
+                'z_dist': self.latent_dist,
+                'loss_func': self.loss_func,
+                'dispersion': self.dispersion,
+                'use_zeroinflate': self.use_zeroinflate,
+                'hidden_layers':self.hidden_layers,
+                'hidden_layer_activation':self.hidden_layer_activation,
+                'nn_dropout':self.nn_dropout,
+                'post_layer_fct':self.post_layer_fct,
+                'post_act_fct':self.post_act_fct,
+                'config_enum':self.config_enum,
+                'use_cuda':self.use_cuda,
+                'seed':self.seed,
+                'zero_bias':self.use_bias,
+                'dtype':self.dtype,
+            }
+        }, path)
+
+    @classmethod
+    def load_model(cls, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path)
+        model = DensityFlow(**checkpoint.get('model_config'))
+
+        checkpoint = torch.load(model_path, map_location=model.get_device())
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+        return model'''


 EXAMPLE_RUN = (
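The loader above stays pickle-based and now reprints the model summary; a state_dict-based save/load pair is added only as commented-out scaffolding. For comparison, a minimal sketch of that commented-out pattern (config keys trimmed to one for brevity; the helper names are invented):

```python
import torch

def save_checkpoint(model, path):
    torch.save({
        "model_state_dict": model.state_dict(),
        "model_config": {"input_size": model.input_size},  # plus the other ctor args
    }, path)

def load_checkpoint(cls, path):
    ckpt = torch.load(path, map_location="cpu")
    model = cls(**ckpt["model_config"])        # rebuild, then restore weights
    model.load_state_dict(ckpt["model_state_dict"])
    return model
```

Unlike pickling the whole object, a state_dict checkpoint survives refactors of unrelated attributes, which may be why the scaffold is being kept around.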
@@ -1357,7 +1439,7 @@ def main():
     df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-
+        dispersion=args.dispersion,
         z_dim=args.z_dim,
         hidden_layers=args.hidden_layers,
         hidden_layer_activation=args.hidden_layer_activation,
SURE/DensityFlow2.py → SURE/DensityFlowLinear.py

@@ -54,7 +54,7 @@ def set_random_seed(seed):
     # Set seed for Pyro
     pyro.set_rng_seed(seed)

-class
+class DensityFlowLinear(nn.Module):
     def __init__(self,
                  input_size: int,
                  codebook_size: int = 200,
@@ -63,8 +63,8 @@ class DensityFlow2(nn.Module):
                  supervised_mode: bool = False,
                  z_dim: int = 10,
                  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
-
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
                  hidden_layers: list = [500],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
@@ -81,7 +81,7 @@ class DensityFlow2(nn.Module):

         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
@@ -107,10 +107,11 @@ class DensityFlow2(nn.Module):

         self.codebook_weights = None

+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()

-        print(f"🧬
+        print(f"🧬 DensityFlowLinear Initialized:")
         print(f"   - Latent Dimension: {self.latent_dim}")
         print(f"   - Gene Dimension: {self.input_size}")
         print(f"   - Hidden Dimensions: {self.hidden_layers}")
@@ -208,17 +209,6 @@ class DensityFlow2(nn.Module):
             use_cuda=self.use_cuda,
         )

-        if self.loss_func == 'negbinomial':
-            self.encoder_inverse_dispersion = MLP(
-                [self.latent_dim] + hidden_sizes + [[self.input_size, self.input_size]],
-                activation=activate_fct,
-                output_activation=[Exp, Exp],
-                post_layer_fct=post_layer_fct,
-                post_act_fct=post_act_fct,
-                allow_broadcast=self.allow_broadcast,
-                use_cuda=self.use_cuda,
-            )
-
         if self.cell_factor_size>0:
             self.cell_factor_effect = nn.ModuleList()
             for i in np.arange(self.cell_factor_size):
@@ -269,7 +259,7 @@ class DensityFlow2(nn.Module):
             )
         )

-        self.
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -352,15 +342,15 @@ class DensityFlow2(nn.Module):
         return xs

     def model1(self, xs):
-        pyro.module('
+        pyro.module('DensityFlowLinear', self)

         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -394,28 +384,31 @@ class DensityFlow2(nn.Module):
             zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

         zs = zns
-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                            logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -432,17 +425,15 @@ class DensityFlow2(nn.Module):
         ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

     def model2(self, xs, us=None):
-        pyro.module('
+        pyro.module('DensityFlowLinear', self)

         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-
-
-            with pyro.plate("genes", self.input_size):
-                inverse_dispersion = pyro.sample("inverse_dispersion", dist.LogNormal(self.inverse_dispersion, 0.5).to_event(1))
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
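This hunk is the most substantive modeling change in DensityFlowLinear: the old code treated inverse dispersion as a latent variable with a LogNormal prior (its amortized LogNormal guide site is removed a few hunks below), while the new code point-estimates a per-gene `pyro.param`. A minimal contrast of the two styles, under invented shapes:

```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

G = 5  # number of genes (illustrative)

def dispersion_as_latent():
    # Old style: a prior in the model, requiring a matching guide site.
    with pyro.plate("genes", G):
        return pyro.sample("inverse_dispersion",
                           dist.LogNormal(torch.zeros(G), 0.5))

def dispersion_as_param():
    # New style: a learned constant, optimized directly by SVI (MAP-like).
    return pyro.param("dispersion", 8.0 * torch.ones(G),
                      constraint=constraints.positive)
```

The param version gives up posterior uncertainty over dispersion but removes an encoder network and a guide site, which simplifies all four model variants at once.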
@@ -482,28 +473,28 @@ class DensityFlow2(nn.Module):
             zs = zns'''

         zs = zns
-
+        log_mu = self.decoder_log_mu(zs)
         for i in np.arange(self.cell_factor_size):
             zus = self._cell_shift(zs, i, us[:,i].reshape(-1,1))
-
+            log_mu += self.decoder_log_mu(zus)

         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
         elif self.loss_func in ['negbinomial']:
-            mu =
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
-            logits = (mu.log()-
+            logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
                             logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
                             logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
@@ -511,8 +502,7 @@ class DensityFlow2(nn.Module):
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
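The `log_mu += self.decoder_log_mu(zus)` lines above are where the "Linear" in the new name shows up: factor effects are decoded separately and summed in log-mean space, so each factor contributes a gene-wise multiplicative fold change to the predicted mean. A sketch of that composition rule (module and shapes are invented, not the package's exact networks):

```python
import torch
import torch.nn as nn

class AdditiveLogMeanDecoder(nn.Module):
    def __init__(self, z_dim: int, n_genes: int):
        super().__init__()
        self.decoder_log_mu = nn.Linear(z_dim, n_genes)

    def forward(self, z_basal, factor_shifts):
        log_mu = self.decoder_log_mu(z_basal)
        for dz in factor_shifts:            # one latent shift per cell factor
            log_mu = log_mu + self.decoder_log_mu(dz)
        # exp(log_mu) = basal mean expression × one fold change per factor
        return log_mu

dec = AdditiveLogMeanDecoder(z_dim=10, n_genes=100)
out = dec(torch.randn(4, 10), [torch.randn(4, 10), torch.randn(4, 10)])
print(out.shape)  # torch.Size([4, 100])
```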
@@ -528,21 +518,16 @@ class DensityFlow2(nn.Module):
         alpha = self.encoder_n(zns)
         ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

-        if self.loss_func == 'negbinomial':
-            id_loc,id_scale = self.encoder_inverse_dispersion(zns)
-            with pyro.plate("genes", self.input_size):
-                pyro.sample("inverse_dispersion", dist.LogNormal(id_loc, id_scale).to_event(1))
-
     def model3(self, xs, ys, embeds=None):
-        pyro.module('
+        pyro.module('DensityFlowLinear', self)

         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -593,28 +578,31 @@ class DensityFlow2(nn.Module):

         zs = zns

-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                            logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -631,15 +619,15 @@ class DensityFlow2(nn.Module):
             zns = embeds

     def model4(self, xs, us, ys, embeds=None):
-        pyro.module('
+        pyro.module('DensityFlowLinear', self)

         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -695,32 +683,34 @@ class DensityFlow2(nn.Module):
             zs = zns'''

         zs = zns
-
+        log_mu = self.decoder_log_mu(zs)
         for i in np.arange(self.cell_factor_size):
             zus = self._cell_shift(zs, i, us[:,i].reshape(-1,1))
-
+            log_mu += self.decoder_log_mu(zus)

         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -887,6 +877,8 @@ class DensityFlow2(nn.Module):

         # basal embedding
         zs = self.get_basal_embedding(xs)
+        log_mu = self.get_log_mu(zs)
+
         for pert in perturbs_predict:
             pert_idx = int(np.where(perturbs_reference==pert)[0])
             us_i = us[:,pert_idx].reshape(-1,1)
@@ -898,10 +890,12 @@ class DensityFlow2(nn.Module):
             ps = np.ones_like(us_i)
             if np.sum(np.abs(ps-us_i))>=1:
                 dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
-
+                delta = dzs0 + dzs
             else:
-
-
+                delta = dzs0
+
+            log_mu = log_mu + self.get_log_mu(delta)
+
         if library_sizes is None:
             library_sizes = np.sum(xs, axis=1, keepdims=True)
         elif type(library_sizes) == list:
@@ -910,9 +904,9 @@ class DensityFlow2(nn.Module):
         elif len(library_sizes.shape)==1:
             library_sizes = library_sizes.reshape(-1,1)

-        counts = self.get_counts(
+        counts = self.get_counts(log_mu, library_sizes=library_sizes)

-        return counts,
+        return counts, log_mu

     def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
@@ -956,48 +950,46 @@ class DensityFlow2(nn.Module):
         Z = np.concatenate(Z)
         return Z

-    def
-        return self.
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)

-    def
-                   delta_zs,
-                   batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor

         """
-
-        dataset = CustomDataset(
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                r = self.
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)

         R = np.concatenate(R)
         return R

-    def _count(self,
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
-            counts = dist.Bernoulli(logits=
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts

-    def get_counts(self,
+    def get_counts(self, log_mu,
                    library_sizes,
                    batch_size: int = 1024):

-
+        log_mu = convert_to_tensor(log_mu, device=self.get_device())

         if type(library_sizes) == list:
             library_sizes = np.array(library_sizes).reshape(-1,1)
@@ -1005,13 +997,13 @@ class DensityFlow2(nn.Module):
             library_sizes = library_sizes.reshape(-1,1)
         ls = convert_to_tensor(library_sizes, device=self.get_device())

-        dataset = CustomDataset2(
+        dataset = CustomDataset2(log_mu,ls)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                counts = self._count(
+            for Mu_batch, L_batch, _ in dataloader:
+                counts = self._count(Mu_batch, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)

@@ -1045,7 +1037,7 @@ class DensityFlow2(nn.Module):
               threshold: int = 0,
               use_jax: bool = True):
         """
-        Train the
+        Train the DensityFlowLinear model.

         Parameters
         ----------
@@ -1071,7 +1063,7 @@ class DensityFlow2(nn.Module):
             Parameter for optimization.
         use_jax
             If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
-            the Python script or Jupyter notebook. It is OK if it is used when runing
+            the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlowLinear in the shell command.
         """
         xs = self.preprocess(xs, threshold=threshold)
         xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1189,12 +1181,12 @@ class DensityFlow2(nn.Module):


 EXAMPLE_RUN = (
-    "example run:
+    "example run: DensityFlowLinear --help"
 )

 def parse_args():
     parser = argparse.ArgumentParser(
-        description="
+        description="DensityFlowLinear\n{}".format(EXAMPLE_RUN))

     parser.add_argument(
         "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1381,10 +1373,10 @@ def main():
     cell_factor_size = 0 if us is None else us.shape[1]

     ###########################################
-    df =
+    df = DensityFlowLinear(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-
+        dispersion=args.dispersion,
         z_dim=args.z_dim,
         hidden_layers=args.hidden_layers,
         hidden_layer_activation=args.hidden_layer_activation,
@@ -1412,9 +1404,9 @@ def main():

     if args.save_model is not None:
         if args.save_model.endswith('gz'):
-
+            DensityFlowLinear.save_model(df, args.save_model, compression=True)
         else:
-
+            DensityFlowLinear.save_model(df, args.save_model)


SURE/__init__.py

@@ -1,6 +1,6 @@
 from .SURE import SURE
 from .DensityFlow import DensityFlow
-from .
+from .DensityFlowLinear import DensityFlowLinear
 from .PerturbE import PerturbE
 from .TranscriptomeDecoder import TranscriptomeDecoder
 from .SimpleTranscriptomeDecoder import SimpleTranscriptomeDecoder
@@ -12,7 +12,7 @@ from . import utils
 from . import codebook
 from . import SURE
 from . import DensityFlow
-from . import
+from . import DensityFlowLinear
 from . import atac
 from . import flow
 from . import perturb
@@ -23,6 +23,6 @@ from . import EfficientTranscriptomeDecoder
 from . import VirtualCellDecoder
 from . import PerturbationAwareDecoder

-__all__ = ['SURE', 'DensityFlow', '
+__all__ = ['SURE', 'DensityFlow', 'DensityFlowLinear', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
            'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'PerturbationAwareDecoder',
            'flow', 'perturb', 'atac', 'utils', 'codebook']
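With the __init__.py changes, the renamed class is importable from the package top level alongside DensityFlow; a hypothetical smoke test (constructor arguments are illustrative, not prescribed by the diff):

```python
from SURE import DensityFlow, DensityFlowLinear

model = DensityFlowLinear(input_size=2000, cell_factor_size=2, dispersion=8.0)
```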
All other files listed above (+0 -0) are unchanged between 2.4.25 and 2.4.42.