SURE-tools 2.4.7-py3-none-any.whl → 2.4.34-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/DensityFlow.py +128 -70
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1414 -0
- SURE/EfficientTranscriptomeDecoder.py +552 -0
- SURE/PerturbationAwareDecoder.py +737 -0
- SURE/SimpleTranscriptomeDecoder.py +567 -0
- SURE/TranscriptomeDecoder.py +13 -1
- SURE/VirtualCellDecoder.py +658 -0
- SURE/__init__.py +13 -1
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.34.dist-info}/METADATA +1 -1
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.34.dist-info}/RECORD +15 -9
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.34.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.34.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.34.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.7.dist-info → sure_tools-2.4.34.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
@@ -57,16 +57,16 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int =
+                 codebook_size: int = 100,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
-                 z_dim: int =
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
-
+                 z_dim: int = 50,
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
-                 hidden_layers: list = [
+                 hidden_layers: list = [1024],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
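Note: a minimal construction sketch using the new 2.4.34 defaults shown above. The import path and the input_size value are assumptions for illustration, not taken from the diff:

    from SURE import DensityFlow

    df = DensityFlow(
        input_size=2000,           # illustrative gene count (assumption)
        codebook_size=100,         # new default
        z_dim=50,                  # new default
        z_dist='studentt',         # new default
        loss_func='negbinomial',   # new default
        dispersion=8.0,            # new parameter: initial NB dispersion
        hidden_layers=[1024],      # new default
    )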
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
 
         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
+        self.config_enum = config_enum
         self.allow_broadcast = config_enum == 'parallel'
         self.use_cuda = use_cuda
         self.loss_func = loss_func
@@ -107,8 +108,16 @@ class DensityFlow(nn.Module):
 
         self.codebook_weights = None
 
+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f"   - Latent Dimension: {self.latent_dim}")
+        print(f"   - Gene Dimension: {self.input_size}")
+        print(f"   - Hidden Dimensions: {self.hidden_layers}")
+        print(f"   - Device: {self.get_device()}")
+        print(f"   - Parameters: {sum(p.numel() for p in self.parameters()):,}")
 
     def setup_networks(self):
         latent_dim = self.latent_dim
@@ -251,7 +260,7 @@ class DensityFlow(nn.Module):
             )
         )
 
-        self.
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -341,8 +350,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
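Note: the new dispersion handling registers a learnable, per-gene positive parameter via pyro.param, initialized at self.dispersion. A standalone sketch of the same call, with the batch xs and input_size invented for illustration:

    import torch
    import pyro
    from torch.distributions import constraints

    input_size = 2000                 # illustrative gene count (assumption)
    xs = torch.ones(8, input_size)    # stand-in batch; only its dtype/device matter here
    dispersion = pyro.param("dispersion", 8.0 * xs.new_ones(input_size),
                            constraint=constraints.positive)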
@@ -376,28 +385,32 @@ class DensityFlow(nn.Module):
                 zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
 
             zs = zns
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
+                                                           logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
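Note: with total_count=dispersion and logits=log(mu)-log(dispersion), the negative binomial's mean is exactly mu and its variance is mu + mu**2/dispersion; the clamp only guards against extreme logits. A quick standalone check (values illustrative):

    import torch
    from torch.distributions import NegativeBinomial

    mu = torch.tensor([5.0, 50.0, 500.0])      # decoded means, i.e. log_mu.exp()
    dispersion = torch.full_like(mu, 8.0)
    logits = (mu.log() - dispersion.log()).clamp(min=-15, max=15)

    nb = NegativeBinomial(total_count=dispersion, logits=logits)
    print(nb.mean)      # ~[5., 50., 500.] == mu
    print(nb.variance)  # mu + mu**2 / dispersion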
@@ -421,8 +434,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -461,28 +474,30 @@ class DensityFlow(nn.Module):
             else:
                 zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -506,8 +521,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -558,28 +573,31 @@ class DensityFlow(nn.Module):
 
             zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -603,8 +621,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
        if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -665,28 +683,31 @@ class DensityFlow(nn.Module):
             else:
                 zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -710,13 +731,13 @@ class DensityFlow(nn.Module):
                 #    zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = self.
+                zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
             else:
                 #if self.turn_off_cell_specific:
                 #    zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = zus + self.
+                zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
         return zus
 
     def _get_codebook_identity(self):
@@ -858,12 +879,12 @@ class DensityFlow(nn.Module):
         us_i = us[:,pert_idx].reshape(-1,1)
 
         # factor effect of xs
-        dzs0 = self.
+        dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
 
         # perturbation effect
         ps = np.ones_like(us_i)
         if np.sum(np.abs(ps-us_i))>=1:
-            dzs = self.
+            dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
             zs = zs + dzs0 + dzs
         else:
             zs = zs + dzs0
@@ -877,10 +898,11 @@ class DensityFlow(nn.Module):
         library_sizes = library_sizes.reshape(-1,1)
 
         counts = self.get_counts(zs, library_sizes=library_sizes)
+        log_mu = self.get_log_mu(zs)
 
-        return counts,
+        return counts, log_mu
 
-    def
+    def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
         #zns,_ = self._get_basal_embedding(xs)
         zns = zs
@@ -897,7 +919,7 @@ class DensityFlow(nn.Module):
 
         return ms
 
-    def
+    def get_cell_shift(self,
                        zs,
                        perturb_idx,
                        perturb_us,
@@ -915,46 +937,43 @@ class DensityFlow(nn.Module):
         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, P_batch, _ in dataloader:
-                zns = self.
+                zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)
 
         Z = np.concatenate(Z)
         return Z
 
-    def
-        return self.
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)
 
-    def
-        delta_zs,
-        batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor
 
         """
-
-        dataset = CustomDataset(
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                r = self.
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)
 
         R = np.concatenate(R)
         return R
 
-    def _count(self,
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
-
-            counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts
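Note: in the refactored decode path, get_log_mu batches the decoder's log-mean over zs and _count converts it to expected counts; for the count likelihoods, DirichletMultinomial(total_count=1, concentration=rate).mean is simply rate normalized over genes, then scaled by library size. An equivalent standalone sketch (shapes and values illustrative):

    import torch

    log_mu = torch.randn(4, 2000)                   # stand-in decoder output
    rate = log_mu.exp()
    theta = rate / rate.sum(dim=1, keepdim=True)    # == DirichletMultinomial(1, rate).mean
    library_size = torch.full((4, 1), 1e4)
    counts = theta * library_size                   # expected counts per cell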
@@ -976,8 +995,8 @@ class DensityFlow(nn.Module):
         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
-
-                counts = self._count(
+                log_mu = self._log_mu(Z_batch)
+                counts = self._count(log_mu, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)
 
@@ -1123,7 +1142,7 @@ class DensityFlow(nn.Module):
                 pbar.set_postfix({'loss': str_loss})
                 pbar.update(1)
 
-    @classmethod
+    '''@classmethod
     def save_model(cls, model, file_path, compression=False):
         """Save the model to the specified file path."""
         file_path = os.path.abspath(file_path)
@@ -1151,6 +1170,45 @@ class DensityFlow(nn.Module):
         with open(file_path, 'rb') as pickle_file:
             model = pickle.load(pickle_file)
 
+        return model'''
+
+    def save_model(self, path):
+        """Save model checkpoint"""
+        torch.save({
+            'model_state_dict': self.state_dict(),
+            'model_config': {
+                'input_size': self.input_size,
+                'codebook_size': self.code_size,
+                'cell_factor_size': self.cell_factor_size,
+                'turn_off_cell_specific':self.turn_off_cell_specific,
+                'supervised_mode':self.supervised_mode,
+                'z_dim': self.latent_dim,
+                'z_dist': self.latent_dist,
+                'loss_func': self.loss_func,
+                'dispersion': self.dispersion,
+                'use_zeroinflate': self.use_zeroinflate,
+                'hidden_layers':self.hidden_layers,
+                'hidden_layer_activation':self.hidden_layer_activation,
+                'nn_dropout':self.nn_dropout,
+                'post_layer_fct':self.post_layer_fct,
+                'post_act_fct':self.post_act_fct,
+                'config_enum':self.config_enum,
+                'use_cuda':self.use_cuda,
+                'seed':self.seed,
+                'zero_bias':self.use_bias,
+                'dtype':self.dtype,
+            }
+        }, path)
+
+    @classmethod
+    def load_model(cls, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path)
+        model = DensityFlow(**checkpoint.get('model_config'))
+
+        checkpoint = torch.load(model_path, map_location=model.get_device())
+        model.load_state_dict(checkpoint['model_state_dict'])
+
         return model
 
 
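Note: the pickle-based classmethods are commented out in favor of a torch checkpoint that stores the constructor config alongside the weights, so load_model can rebuild the model without the original arguments. A hedged usage sketch (file name illustrative):

    df.save_model('densityflow.pt')                 # writes state_dict + model_config

    df2 = DensityFlow.load_model('densityflow.pt')  # rebuilds from model_config, loads weights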
@@ -1350,7 +1408,7 @@ def main():
     df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-
+        dispersion=args.dispersion,
         z_dim=args.z_dim,
         hidden_layers=args.hidden_layers,
         hidden_layer_activation=args.hidden_layer_activation,