SURE-tools 2.4.22__py3-none-any.whl → 2.4.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools has been flagged as potentially problematic.
- SURE/DensityFlow.py +151 -69
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1414 -0
- SURE/PerturbationAwareDecoder.py +162 -148
- SURE/__init__.py +3 -1
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/METADATA +1 -1
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/RECORD +11 -9
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.43.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
@@ -57,16 +57,16 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int =
+                 codebook_size: int = 30,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
-                 z_dim: int =
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
-
+                 z_dim: int = 50,
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
-                 hidden_layers: list = [
+                 hidden_layers: list = [1024],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
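These new defaults change the out-of-the-box configuration: a Student-t latent prior, a negative-binomial likelihood, and a tunable initial dispersion. As orientation, a minimal construction sketch that spells the new defaults out explicitly (the import path and gene count are illustrative assumptions; all untouched parameters keep their package defaults):

    from SURE.DensityFlow import DensityFlow

    df = DensityFlow(
        input_size=2000,           # hypothetical number of genes
        codebook_size=30,          # new default
        z_dim=50,                  # new default
        z_dist='studentt',         # new default
        loss_func='negbinomial',   # new default
        dispersion=8.0,            # new parameter, default 8.0
        hidden_layers=[1024],      # new default
    )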
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
 
         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
+        self.config_enum = config_enum
         self.allow_broadcast = config_enum == 'parallel'
         self.use_cuda = use_cuda
         self.loss_func = loss_func
@@ -107,10 +108,12 @@ class DensityFlow(nn.Module):
 
         self.codebook_weights = None
 
+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()
 
         print(f"🧬 DensityFlow Initialized:")
+        print(f" - Codebook size: {self.code_size}")
         print(f" - Latent Dimension: {self.latent_dim}")
         print(f" - Gene Dimension: {self.input_size}")
         print(f" - Hidden Dimensions: {self.hidden_layers}")
@@ -258,7 +261,7 @@ class DensityFlow(nn.Module):
             )
         )
 
-        self.
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -348,8 +351,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -383,28 +386,32 @@ class DensityFlow(nn.Module):
             zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
 
         zs = zns
-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func == 'negbinomial':
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                    logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
+                    logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
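The hunk above reparameterizes the negative-binomial likelihood: the decoder now emits log_mu, a learnable per-gene dispersion becomes the NB total_count, and logits = log(mu) - log(dispersion). Since NegativeBinomial(total_count=r, logits=l) has mean r*exp(l), this choice recovers mu as the distribution mean. A minimal sketch verifying that identity (tensor shapes are illustrative assumptions):

    import torch
    from torch.distributions import NegativeBinomial

    log_mu = torch.randn(4, 10)           # stand-in for decoder_log_mu(zs)
    dispersion = torch.full((10,), 8.0)   # the new default initial dispersion
    mu = log_mu.exp()
    logits = (mu.log() - dispersion.log()).clamp(min=-10, max=10)
    nb = NegativeBinomial(total_count=dispersion, logits=logits)
    # mean = total_count * exp(logits) = dispersion * (mu / dispersion) = mu
    assert torch.allclose(nb.mean, mu, atol=1e-4)

The clamp keeps the logits in a numerically safe range; it only biases the mean for genes whose mu/dispersion ratio falls outside e^±10.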
@@ -428,8 +435,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -468,28 +475,30 @@ class DensityFlow(nn.Module):
         else:
             zs = zns
 
-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func == 'negbinomial':
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -513,8 +522,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -565,28 +574,31 @@ class DensityFlow(nn.Module):
 
         zs = zns
 
-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                    logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -610,8 +622,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -672,28 +684,31 @@ class DensityFlow(nn.Module):
         else:
             zs = zns
 
-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                    logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -717,13 +732,13 @@ class DensityFlow(nn.Module):
                 # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = self.
+                zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
             else:
                 #if self.turn_off_cell_specific:
                 # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = zus + self.
+                zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
         return zus
 
     def _get_codebook_identity(self):
@@ -744,6 +759,28 @@ class DensityFlow(nn.Module):
         cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb
+
+    def _get_complete_embedding(self, xs, us):
+        basal,_ = self._get_basal_embedding(xs)
+        dzs = self._total_effects(basal, us)
+        return basal + dzs
+
+    def get_complete_embedding(self, xs, us, batch_size:int=1024):
+        xs = self.preprocess(xs)
+        xs = convert_to_tensor(xs, device=self.get_device())
+        us = convert_to_tensor(us, device=self.get_device())
+        dataset = CustomDataset2(xs, us)
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+        Z = []
+        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+            for X_batch, U_batch, _ in dataloader:
+                zns = self._get_complete_embedding(X_batch, U_batch)
+                Z.append(tensor_to_numpy(zns))
+                pbar.update(1)
+
+        Z = np.concatenate(Z)
+        return Z
 
     def _get_basal_embedding(self, xs):
         loc, scale = self.encoder_zn(xs)
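The new get_complete_embedding pairs the basal latent embedding with the summed cell-factor shifts from _total_effects. A hypothetical usage sketch, assuming a trained model and that DensityFlow is importable from SURE.DensityFlow (array shapes and names are illustrative, not from the package docs):

    import numpy as np
    from SURE.DensityFlow import DensityFlow

    X = np.random.poisson(1.0, size=(512, 2000)).astype('float32')   # counts: cells x genes
    U = np.random.binomial(1, 0.5, size=(512, 1)).astype('float32')  # one cell-level factor

    df = DensityFlow(input_size=X.shape[1], cell_factor_size=U.shape[1])
    # ... fit df on (X, U) first; an untrained model returns arbitrary embeddings ...
    Z = df.get_complete_embedding(X, U, batch_size=256)  # (512, z_dim): basal + factor shifts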
@@ -865,12 +902,12 @@ class DensityFlow(nn.Module):
             us_i = us[:,pert_idx].reshape(-1,1)
 
             # factor effect of xs
-            dzs0 = self.
+            dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
 
             # perturbation effect
             ps = np.ones_like(us_i)
             if np.sum(np.abs(ps-us_i))>=1:
-                dzs = self.
+                dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
                 zs = zs + dzs0 + dzs
             else:
                 zs = zs + dzs0
@@ -884,10 +921,11 @@ class DensityFlow(nn.Module):
         library_sizes = library_sizes.reshape(-1,1)
 
         counts = self.get_counts(zs, library_sizes=library_sizes)
+        log_mu = self.get_log_mu(zs)
 
-        return counts,
+        return counts, log_mu
 
-    def
+    def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
         #zns,_ = self._get_basal_embedding(xs)
         zns = zs
@@ -904,7 +942,7 @@ class DensityFlow(nn.Module):
 
         return ms
 
-    def
+    def get_cell_shift(self,
                        zs,
                        perturb_idx,
                        perturb_us,
@@ -922,46 +960,43 @@ class DensityFlow(nn.Module):
         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, P_batch, _ in dataloader:
-                zns = self.
+                zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)
 
         Z = np.concatenate(Z)
         return Z
 
-    def
-        return self.
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)
 
-    def
-        delta_zs,
-        batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor
 
         """
-
-        dataset = CustomDataset(
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                r = self.
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)
 
         R = np.concatenate(R)
         return R
 
-    def _count(self,
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
-
-            counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts
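In the rebuilt _count, the else branch turns decoded rates into per-gene proportions before scaling by library size: DirichletMultinomial(total_count=1, concentration=rate).mean is simply rate normalized along the gene axis. A small sketch of that step (batch shapes are illustrative assumptions):

    import torch
    from pyro.distributions import DirichletMultinomial

    rate = torch.rand(4, 10) + 0.1          # stand-in for log_mu.exp()
    library_size = torch.full((4, 1), 1e4)  # per-cell totals
    theta = DirichletMultinomial(total_count=1, concentration=rate).mean
    assert torch.allclose(theta, rate / rate.sum(-1, keepdim=True))
    counts = theta * library_size           # expected counts per gene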
@@ -983,8 +1018,8 @@ class DensityFlow(nn.Module):
         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
-
-                counts = self._count(
+                log_mu = self._log_mu(Z_batch)
+                counts = self._count(log_mu, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)
 
@@ -1157,8 +1192,55 @@ class DensityFlow(nn.Module):
         else:
             with open(file_path, 'rb') as pickle_file:
                 model = pickle.load(pickle_file)
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f" - Codebook size: {model.code_size}")
+        print(f" - Latent Dimension: {model.latent_dim}")
+        print(f" - Gene Dimension: {model.input_size}")
+        print(f" - Hidden Dimensions: {model.hidden_layers}")
+        print(f" - Device: {model.get_device()}")
+        print(f" - Parameters: {sum(p.numel() for p in model.parameters()):,}")
 
         return model
+
+    ''' def save(self, path):
+        """Save model checkpoint"""
+        torch.save({
+            'model_state_dict': self.state_dict(),
+            'model_config': {
+                'input_size': self.input_size,
+                'codebook_size': self.code_size,
+                'cell_factor_size': self.cell_factor_size,
+                'turn_off_cell_specific':self.turn_off_cell_specific,
+                'supervised_mode':self.supervised_mode,
+                'z_dim': self.latent_dim,
+                'z_dist': self.latent_dist,
+                'loss_func': self.loss_func,
+                'dispersion': self.dispersion,
+                'use_zeroinflate': self.use_zeroinflate,
+                'hidden_layers':self.hidden_layers,
+                'hidden_layer_activation':self.hidden_layer_activation,
+                'nn_dropout':self.nn_dropout,
+                'post_layer_fct':self.post_layer_fct,
+                'post_act_fct':self.post_act_fct,
+                'config_enum':self.config_enum,
+                'use_cuda':self.use_cuda,
+                'seed':self.seed,
+                'zero_bias':self.use_bias,
+                'dtype':self.dtype,
+            }
+        }, path)
+
+    @classmethod
+    def load_model(cls, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path)
+        model = DensityFlow(**checkpoint.get('model_config'))
+
+        checkpoint = torch.load(model_path, map_location=model.get_device())
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+        return model'''
 
 
 EXAMPLE_RUN = (
@@ -1357,7 +1439,7 @@ def main():
     df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-
+        dispersion=args.dispersion,
         z_dim=args.z_dim,
         hidden_layers=args.hidden_layers,
         hidden_layer_activation=args.hidden_layer_activation,