SURE-tools 2.4.20__tar.gz → 2.4.35__tar.gz
This diff shows the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only.
- {sure_tools-2.4.20 → sure_tools-2.4.35}/PKG-INFO +1 -1
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/DensityFlow.py +122 -71
- sure_tools-2.4.35/SURE/DensityFlowLinear.py +1414 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/PerturbationAwareDecoder.py +217 -250
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/__init__.py +3 -1
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/SOURCES.txt +1 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/setup.py +1 -1
- {sure_tools-2.4.20 → sure_tools-2.4.35}/LICENSE +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/README.md +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/EfficientTranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/PerturbE.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/SURE.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/SimpleTranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/TranscriptomeDecoder.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/VirtualCellDecoder.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/atac/utils.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/utils/queue.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/utils/utils.py +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.4.20 → sure_tools-2.4.35}/setup.cfg +0 -0
```diff
--- sure_tools-2.4.20/SURE/DensityFlow.py
+++ sure_tools-2.4.35/SURE/DensityFlow.py
@@ -57,16 +57,16 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int =
+                 codebook_size: int = 100,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
-                 z_dim: int =
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
-
+                 z_dim: int = 50,
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
-                 hidden_layers: list = [
+                 hidden_layers: list = [1024],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
```
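The constructor now exposes a tunable `dispersion` and ships new defaults for the codebook size, latent dimension, latent distribution, and likelihood. A hypothetical construction sketch using those defaults; the `SURE` import path and the `input_size` value are assumptions, not taken from this diff:

```python
# Hypothetical usage of the 2.4.35 constructor defaults shown above;
# the import path and input_size are illustrative assumptions.
from SURE import DensityFlow

df = DensityFlow(
    input_size=2000,           # e.g. number of genes (illustrative)
    codebook_size=100,         # new default
    z_dim=50,                  # new default
    z_dist='studentt',         # new default
    loss_func='negbinomial',   # new default
    dispersion=8.0,            # newly exposed initial dispersion
)
```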
```diff
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
 
         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
+        self.config_enum = config_enum
         self.allow_broadcast = config_enum == 'parallel'
         self.use_cuda = use_cuda
         self.loss_func = loss_func
@@ -107,6 +108,7 @@ class DensityFlow(nn.Module):
 
         self.codebook_weights = None
 
+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()
 
@@ -115,7 +117,7 @@ class DensityFlow(nn.Module):
         print(f" - Gene Dimension: {self.input_size}")
         print(f" - Hidden Dimensions: {self.hidden_layers}")
         print(f" - Device: {self.get_device()}")
-        print(f" - Parameters: {sum(p.numel() for p in self.
+        print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
 
     def setup_networks(self):
         latent_dim = self.latent_dim
@@ -258,7 +260,7 @@ class DensityFlow(nn.Module):
             )
         )
 
-        self.
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -348,8 +350,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
```
```diff
@@ -383,28 +385,32 @@ class DensityFlow(nn.Module):
                 zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
 
             zs = zns
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
+                                                           logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
```
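The model now learns a per-gene `dispersion` as a positive `pyro.param` and feeds the negative binomial through `logits = log(mu) - log(dispersion)`, clamped to avoid numerical blow-ups. A standalone sketch (plain torch, not the pyro model) of why this parameterization is mean-preserving:

```python
# With total_count=r and logits=log(mu)-log(r), the negative binomial mean
# is r * exp(logits) = mu, so the decoder's log_mu is the log of the mean.
import torch
from torch.distributions import NegativeBinomial

r = torch.tensor(8.0)                               # dispersion (new default 8.0)
log_mu = torch.tensor([2.0, 0.5, -1.0])             # decoder output log-means
logits = (log_mu - r.log()).clamp(min=-15, max=15)  # same clamp as the diff
nb = NegativeBinomial(total_count=r, logits=logits)

assert torch.allclose(nb.mean, log_mu.exp())        # mean recovers mu
# Variance is mu + mu**2 / r: smaller r means stronger overdispersion.
```

The multinomial branch makes a related fix, sampling with `probs=theta` instead of `logits=concentrate`, since `concentrate` no longer exists now that the decoder output is named `log_mu`.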
```diff
@@ -428,8 +434,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -468,28 +474,30 @@ class DensityFlow(nn.Module):
             else:
                 zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -513,8 +521,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -565,28 +573,31 @@ class DensityFlow(nn.Module):
 
             zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -610,8 +621,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -672,28 +683,31 @@ class DensityFlow(nn.Module):
             else:
                 zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -717,13 +731,13 @@ class DensityFlow(nn.Module):
                 # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = self.
+                zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
             else:
                 #if self.turn_off_cell_specific:
                 # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = zus + self.
+                zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
         return zus
 
     def _get_codebook_identity(self):
@@ -865,12 +879,12 @@ class DensityFlow(nn.Module):
         us_i = us[:,pert_idx].reshape(-1,1)
 
         # factor effect of xs
-        dzs0 = self.
+        dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
 
         # perturbation effect
         ps = np.ones_like(us_i)
         if np.sum(np.abs(ps-us_i))>=1:
-            dzs = self.
+            dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
             zs = zs + dzs0 + dzs
         else:
             zs = zs + dzs0
```
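The private factor-effect helper is renamed `_cell_shift`, with a batched public counterpart `get_cell_shift` (defined in the following hunks). A hypothetical call pattern mirroring the internal usage above; the shapes, the factor index, and the use of NumPy inputs are illustrative assumptions:

```python
# Hypothetical call of get_cell_shift, mirroring the internal usage above,
# continuing the constructed `df` from the earlier sketch.
import numpy as np

zs = np.random.randn(128, 50)   # basal latent embeddings (z_dim=50)
us_i = np.ones((128, 1))        # unit dose for one cell factor
dzs = df.get_cell_shift(zs, perturb_idx=0, perturb_us=us_i)
zs_shifted = zs + dzs           # shifts compose additively, as in the diff
```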
```diff
@@ -884,10 +898,11 @@ class DensityFlow(nn.Module):
         library_sizes = library_sizes.reshape(-1,1)
 
         counts = self.get_counts(zs, library_sizes=library_sizes)
+        log_mu = self.get_log_mu(zs)
 
-        return counts,
+        return counts, log_mu
 
-    def
+    def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
         #zns,_ = self._get_basal_embedding(xs)
         zns = zs
@@ -904,7 +919,7 @@ class DensityFlow(nn.Module):
 
         return ms
 
-    def
+    def get_cell_shift(self,
                        zs,
                        perturb_idx,
                        perturb_us,
@@ -922,46 +937,43 @@ class DensityFlow(nn.Module):
         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, P_batch, _ in dataloader:
-                zns = self.
+                zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)
 
         Z = np.concatenate(Z)
         return Z
 
-    def
-        return self.
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)
 
-    def
-                   delta_zs,
-                   batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor
 
         """
-
-        dataset = CustomDataset(
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                r = self.
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)
 
         R = np.concatenate(R)
         return R
 
-    def _count(self,
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
-
-            counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts
```
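For the default `negbinomial` loss, the new `_count` helper above turns `log_mu` into expected counts by normalizing `exp(log_mu)` to per-gene proportions and rescaling by library size; `dist.DirichletMultinomial(total_count=1, concentration=rate).mean` is simply `rate / rate.sum()`. A NumPy re-derivation of that step, continuing the sketch above and assuming `log_mu` comes from `get_log_mu`:

```python
# NumPy re-derivation of the _count step for the default loss;
# the library size of 1e4 is an illustrative sequencing depth.
import numpy as np

log_mu = df.get_log_mu(zs)                        # (n_cells, n_genes)
rate = np.exp(log_mu)
theta = rate / rate.sum(axis=1, keepdims=True)    # DirichletMultinomial(1, rate).mean
library_size = np.full((log_mu.shape[0], 1), 1e4)
counts = theta * library_size                     # expected counts, as in _count
```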
```diff
@@ -983,8 +995,8 @@ class DensityFlow(nn.Module):
         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
-
-                counts = self._count(
+                log_mu = self._log_mu(Z_batch)
+                counts = self._count(log_mu, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)
 
@@ -1130,7 +1142,7 @@ class DensityFlow(nn.Module):
                 pbar.set_postfix({'loss': str_loss})
                 pbar.update(1)
 
-    @classmethod
+    '''@classmethod
     def save_model(cls, model, file_path, compression=False):
         """Save the model to the specified file path."""
         file_path = os.path.abspath(file_path)
@@ -1158,6 +1170,45 @@ class DensityFlow(nn.Module):
         with open(file_path, 'rb') as pickle_file:
             model = pickle.load(pickle_file)
 
+        return model'''
+
+    def save(self, path):
+        """Save model checkpoint"""
+        torch.save({
+            'model_state_dict': self.state_dict(),
+            'model_config': {
+                'input_size': self.input_size,
+                'codebook_size': self.code_size,
+                'cell_factor_size': self.cell_factor_size,
+                'turn_off_cell_specific':self.turn_off_cell_specific,
+                'supervised_mode':self.supervised_mode,
+                'z_dim': self.latent_dim,
+                'z_dist': self.latent_dist,
+                'loss_func': self.loss_func,
+                'dispersion': self.dispersion,
+                'use_zeroinflate': self.use_zeroinflate,
+                'hidden_layers':self.hidden_layers,
+                'hidden_layer_activation':self.hidden_layer_activation,
+                'nn_dropout':self.nn_dropout,
+                'post_layer_fct':self.post_layer_fct,
+                'post_act_fct':self.post_act_fct,
+                'config_enum':self.config_enum,
+                'use_cuda':self.use_cuda,
+                'seed':self.seed,
+                'zero_bias':self.use_bias,
+                'dtype':self.dtype,
+            }
+        }, path)
+
+    @classmethod
+    def load_model(cls, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path)
+        model = DensityFlow(**checkpoint.get('model_config'))
+
+        checkpoint = torch.load(model_path, map_location=model.get_device())
+        model.load_state_dict(checkpoint['model_state_dict'])
+
         return model
 
 
```
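The release comments out the old pickle-based `save_model`/`load_model` pair in favor of a `torch.save` checkpoint that stores both the weights and the constructor configuration, so `load_model` can rebuild the model before restoring its state. A round-trip sketch, continuing the example above (the path is illustrative):

```python
# Hypothetical round trip with the new checkpoint API from the hunk above.
df.save('density_flow.pt')                  # weights + constructor config
df_restored = DensityFlow.load_model('density_flow.pt')
```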
```diff
@@ -1357,7 +1408,7 @@ def main():
     df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-
+        dispersion=args.dispersion,
         z_dim=args.z_dim,
         hidden_layers=args.hidden_layers,
         hidden_layer_activation=args.hidden_layer_activation,
```