SURE-tools 2.4.7__py3-none-any.whl → 2.4.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/DensityFlow.py +159 -70
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1414 -0
- SURE/EfficientTranscriptomeDecoder.py +552 -0
- SURE/PerturbationAwareDecoder.py +737 -0
- SURE/SimpleTranscriptomeDecoder.py +567 -0
- SURE/TranscriptomeDecoder.py +13 -1
- SURE/VirtualCellDecoder.py +658 -0
- SURE/__init__.py +13 -1
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/METADATA +1 -1
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/RECORD +15 -9
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.42.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
@@ -57,16 +57,16 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int =
+                 codebook_size: int = 30,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
-                 z_dim: int =
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
-
+                 z_dim: int = 50,
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
-                 hidden_layers: list = [
+                 hidden_layers: list = [1024],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
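Note: the hunk above changes the constructor defaults. A minimal construction sketch (not part of the diff; the import path and the input_size value are assumptions) with the new 2.4.42 defaults written out:

    from SURE.DensityFlow import DensityFlow

    df = DensityFlow(
        input_size=2000,           # number of genes; example value
        codebook_size=30,          # new default
        z_dim=50,                  # new default
        z_dist='studentt',         # new default
        loss_func='negbinomial',   # new default
        dispersion=8.0,            # hyperparameter introduced in this release
        hidden_layers=[1024],      # new default
    )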
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
 
         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
+        self.config_enum = config_enum
         self.allow_broadcast = config_enum == 'parallel'
         self.use_cuda = use_cuda
         self.loss_func = loss_func
@@ -107,8 +108,17 @@ class DensityFlow(nn.Module):
 
         self.codebook_weights = None
 
+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f"   - Codebook size: {self.code_size}")
+        print(f"   - Latent Dimension: {self.latent_dim}")
+        print(f"   - Gene Dimension: {self.input_size}")
+        print(f"   - Hidden Dimensions: {self.hidden_layers}")
+        print(f"   - Device: {self.get_device()}")
+        print(f"   - Parameters: {sum(p.numel() for p in self.parameters()):,}")
 
     def setup_networks(self):
         latent_dim = self.latent_dim
@@ -251,7 +261,7 @@ class DensityFlow(nn.Module):
             )
         )
 
-        self.
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -341,8 +351,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -376,28 +386,32 @@ class DensityFlow(nn.Module):
             zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
 
         zs = zns
-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func == 'negbinomial':
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                     logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
+                                                       logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
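Note on the new negative-binomial parameterization above: for torch/Pyro's NegativeBinomial(total_count=r, logits=l), the mean is r * exp(l), so setting r = dispersion and l = log(mu) - log(dispersion) recovers mean mu with variance mu + mu**2 / dispersion. A standalone check (illustration, not package code):

    import torch
    from pyro.distributions import NegativeBinomial

    mu = torch.tensor([5.0, 50.0])
    dispersion = torch.tensor([8.0, 8.0])
    logits = (mu.log() - dispersion.log()).clamp(min=-10, max=10)
    nb = NegativeBinomial(total_count=dispersion, logits=logits)
    assert torch.allclose(nb.mean, mu)   # variance == mu + mu**2 / dispersion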
@@ -421,8 +435,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -461,28 +475,30 @@ class DensityFlow(nn.Module):
         else:
             zs = zns
 
-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func == 'negbinomial':
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -506,8 +522,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -558,28 +574,31 @@ class DensityFlow(nn.Module):
 
         zs = zns
 
-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                     logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -603,8 +622,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -665,28 +684,31 @@ class DensityFlow(nn.Module):
         else:
             zs = zns
 
-
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta =
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                     logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -710,13 +732,13 @@ class DensityFlow(nn.Module):
                 #    zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = self.
+                zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
             else:
                 #if self.turn_off_cell_specific:
                 #    zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = zus + self.
+                zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
         return zus
 
     def _get_codebook_identity(self):
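Note: read in isolation, the loop above accumulates one additive latent shift per cell factor via the renamed _cell_shift helper. A toy NumPy illustration of that accumulation (the effect function here is a placeholder, not the model's learned per-factor network):

    import numpy as np

    def cell_shift(zns, i, u_i):
        # placeholder for the learned per-factor effect network
        return u_i * np.ones((1, zns.shape[1])) * 0.01 * (i + 1)

    zns = np.zeros((3, 4))                                  # (cells, z_dim)
    us = np.array([[1., 0.], [0., 1.], [1., 1.]])           # (cells, factors)
    zus = sum(cell_shift(zns, i, us[:, i].reshape(-1, 1))   # summed shifts
              for i in range(us.shape[1]))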
@@ -737,6 +759,28 @@ class DensityFlow(nn.Module):
         cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb
+
+    def _get_complete_embedding(self, xs, us):
+        basal,_ = self._get_basal_embedding(xs)
+        dzs = self._total_effects(basal, us)
+        return basal + dzs
+
+    def get_complete_embedding(self, xs, us, batch_size:int=1024):
+        xs = self.preprocess(xs)
+        xs = convert_to_tensor(xs, device=self.get_device())
+        us = convert_to_tensor(us, device=self.get_device())
+        dataset = CustomDataset2(xs, us)
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+        Z = []
+        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+            for X_batch, U_batch, _ in dataloader:
+                zns = self._get_complete_embedding(X_batch, U_batch)
+                Z.append(tensor_to_numpy(zns))
+                pbar.update(1)
+
+        Z = np.concatenate(Z)
+        return Z
 
     def _get_basal_embedding(self, xs):
         loc, scale = self.encoder_zn(xs)
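Note: a hypothetical call sketch for the new public method above (df as constructed earlier; shapes are assumptions: xs is a cells x genes count matrix, us is a cells x cell_factor_size factor matrix):

    # basal embedding of xs plus the summed cell-factor effects of us,
    # returned as a NumPy array of shape (n_cells, z_dim)
    Z = df.get_complete_embedding(xs, us, batch_size=1024)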
@@ -858,12 +902,12 @@ class DensityFlow(nn.Module):
         us_i = us[:,pert_idx].reshape(-1,1)
 
         # factor effect of xs
-        dzs0 = self.
+        dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
 
         # perturbation effect
         ps = np.ones_like(us_i)
         if np.sum(np.abs(ps-us_i))>=1:
-            dzs = self.
+            dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
             zs = zs + dzs0 + dzs
         else:
             zs = zs + dzs0
@@ -877,10 +921,11 @@ class DensityFlow(nn.Module):
         library_sizes = library_sizes.reshape(-1,1)
 
         counts = self.get_counts(zs, library_sizes=library_sizes)
+        log_mu = self.get_log_mu(zs)
 
-        return counts,
+        return counts, log_mu
 
-    def
+    def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
         #zns,_ = self._get_basal_embedding(xs)
         zns = zs
@@ -897,7 +942,7 @@ class DensityFlow(nn.Module):
 
         return ms
 
-    def
+    def get_cell_shift(self,
                        zs,
                        perturb_idx,
                        perturb_us,
@@ -915,46 +960,43 @@ class DensityFlow(nn.Module):
         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, P_batch, _ in dataloader:
-                zns = self.
+                zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)
 
         Z = np.concatenate(Z)
         return Z
 
-    def
-        return self.
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)
 
-    def
-                      delta_zs,
-                      batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor
 
         """
-
-        dataset = CustomDataset(
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                r = self.
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)
 
         R = np.concatenate(R)
         return R
 
-    def _count(self,
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
-
-            counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts
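Note: for the count likelihoods, _count maps decoder output back to expected counts by normalizing exp(log_mu) through a DirichletMultinomial mean and scaling by the per-cell library size. A standalone illustration of that arithmetic (not package API):

    import torch
    import pyro.distributions as dist

    log_mu = torch.randn(4, 100)                 # decoder output (cells x genes)
    library_size = torch.full((4, 1), 1.0e4)     # per-cell totals
    rate = log_mu.exp()
    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
    counts = theta * library_size                # expected counts; rows sum to 1e4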
@@ -976,8 +1018,8 @@ class DensityFlow(nn.Module):
         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
-
-                counts = self._count(
+                log_mu = self._log_mu(Z_batch)
+                counts = self._count(log_mu, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)
 
@@ -989,7 +1031,7 @@ class DensityFlow(nn.Module):
             ad = sc.AnnData(xs)
             binarize(ad, threshold=threshold)
             xs = ad.X.copy()
-
+        elif self.loss_func == 'poisson':
             xs = np.round(xs)
 
         if sparse.issparse(xs):
@@ -1150,8 +1192,55 @@ class DensityFlow(nn.Module):
         else:
             with open(file_path, 'rb') as pickle_file:
                 model = pickle.load(pickle_file)
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f"   - Codebook size: {model.code_size}")
+        print(f"   - Latent Dimension: {model.latent_dim}")
+        print(f"   - Gene Dimension: {model.input_size}")
+        print(f"   - Hidden Dimensions: {model.hidden_layers}")
+        print(f"   - Device: {model.get_device()}")
+        print(f"   - Parameters: {sum(p.numel() for p in model.parameters()):,}")
 
         return model
+
+    ''' def save(self, path):
+        """Save model checkpoint"""
+        torch.save({
+            'model_state_dict': self.state_dict(),
+            'model_config': {
+                'input_size': self.input_size,
+                'codebook_size': self.code_size,
+                'cell_factor_size': self.cell_factor_size,
+                'turn_off_cell_specific':self.turn_off_cell_specific,
+                'supervised_mode':self.supervised_mode,
+                'z_dim': self.latent_dim,
+                'z_dist': self.latent_dist,
+                'loss_func': self.loss_func,
+                'dispersion': self.dispersion,
+                'use_zeroinflate': self.use_zeroinflate,
+                'hidden_layers':self.hidden_layers,
+                'hidden_layer_activation':self.hidden_layer_activation,
+                'nn_dropout':self.nn_dropout,
+                'post_layer_fct':self.post_layer_fct,
+                'post_act_fct':self.post_act_fct,
+                'config_enum':self.config_enum,
+                'use_cuda':self.use_cuda,
+                'seed':self.seed,
+                'zero_bias':self.use_bias,
+                'dtype':self.dtype,
+            }
+        }, path)
+
+    @classmethod
+    def load_model(cls, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path)
+        model = DensityFlow(**checkpoint.get('model_config'))
+
+        checkpoint = torch.load(model_path, map_location=model.get_device())
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+        return model'''
 
 
 EXAMPLE_RUN = (
@@ -1350,7 +1439,7 @@ def main():
     df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-
+        dispersion=args.dispersion,
         z_dim=args.z_dim,
         hidden_layers=args.hidden_layers,
         hidden_layer_activation=args.hidden_layer_activation,