SURE-tools 2.4.17__py3-none-any.whl → 2.4.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/DensityFlow.py +85 -68
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1413 -0
- SURE/PerturbationAwareDecoder.py +737 -0
- SURE/VirtualCellDecoder.py +0 -1
- SURE/__init__.py +7 -2
- {sure_tools-2.4.17.dist-info → sure_tools-2.4.32.dist-info}/METADATA +1 -1
- {sure_tools-2.4.17.dist-info → sure_tools-2.4.32.dist-info}/RECORD +12 -9
- {sure_tools-2.4.17.dist-info → sure_tools-2.4.32.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.17.dist-info → sure_tools-2.4.32.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.17.dist-info → sure_tools-2.4.32.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.17.dist-info → sure_tools-2.4.32.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
|
@@ -57,16 +57,16 @@ def set_random_seed(seed):
|
|
|
57
57
|
class DensityFlow(nn.Module):
|
|
58
58
|
def __init__(self,
|
|
59
59
|
input_size: int,
|
|
60
|
-
codebook_size: int =
|
|
60
|
+
codebook_size: int = 100,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
62
|
turn_off_cell_specific: bool = False,
|
|
63
63
|
supervised_mode: bool = False,
|
|
64
|
-
z_dim: int =
|
|
64
|
+
z_dim: int = 50,
|
|
65
65
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
66
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
67
|
-
|
|
66
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
|
|
67
|
+
dispersion: float = 8.0,
|
|
68
68
|
use_zeroinflate: bool = False,
|
|
69
|
-
hidden_layers: list = [
|
|
69
|
+
hidden_layers: list = [1024],
|
|
70
70
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
71
71
|
nn_dropout: float = 0.1,
|
|
72
72
|
post_layer_fct: list = ['layernorm'],
|
|
@@ -81,7 +81,7 @@ class DensityFlow(nn.Module):
|
|
|
81
81
|
|
|
82
82
|
self.input_size = input_size
|
|
83
83
|
self.cell_factor_size = cell_factor_size
|
|
84
|
-
self.
|
|
84
|
+
self.dispersion = dispersion
|
|
85
85
|
self.latent_dim = z_dim
|
|
86
86
|
self.hidden_layers = hidden_layers
|
|
87
87
|
self.decoder_hidden_layers = hidden_layers[::-1]
|
|
@@ -109,6 +109,13 @@ class DensityFlow(nn.Module):
|
|
|
109
109
|
|
|
110
110
|
set_random_seed(seed)
|
|
111
111
|
self.setup_networks()
|
|
112
|
+
|
|
113
|
+
print(f"🧬 DensityFlow Initialized:")
|
|
114
|
+
print(f" - Latent Dimension: {self.latent_dim}")
|
|
115
|
+
print(f" - Gene Dimension: {self.input_size}")
|
|
116
|
+
print(f" - Hidden Dimensions: {self.hidden_layers}")
|
|
117
|
+
print(f" - Device: {self.get_device()}")
|
|
118
|
+
print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
|
|
112
119
|
|
|
113
120
|
def setup_networks(self):
|
|
114
121
|
latent_dim = self.latent_dim
|
|
@@ -251,7 +258,7 @@ class DensityFlow(nn.Module):
|
|
|
251
258
|
)
|
|
252
259
|
)
|
|
253
260
|
|
|
254
|
-
self.
|
|
261
|
+
self.decoder_log_mu = MLP(
|
|
255
262
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
256
263
|
activation=activate_fct,
|
|
257
264
|
output_activation=None,
|
|
@@ -341,8 +348,8 @@ class DensityFlow(nn.Module):
|
|
|
341
348
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
342
349
|
|
|
343
350
|
if self.loss_func=='negbinomial':
|
|
344
|
-
|
|
345
|
-
|
|
351
|
+
dispersion = pyro.param("dispersion", self.dispersion *
|
|
352
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
346
353
|
|
|
347
354
|
if self.use_zeroinflate:
|
|
348
355
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -376,28 +383,32 @@ class DensityFlow(nn.Module):
|
|
|
376
383
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
377
384
|
|
|
378
385
|
zs = zns
|
|
379
|
-
|
|
386
|
+
log_mu = self.decoder_log_mu(zs)
|
|
380
387
|
if self.loss_func in ['bernoulli']:
|
|
381
|
-
log_theta =
|
|
388
|
+
log_theta = log_mu
|
|
389
|
+
elif self.loss_func == 'negbinomial':
|
|
390
|
+
mu = log_mu.exp()
|
|
382
391
|
else:
|
|
383
|
-
rate =
|
|
392
|
+
rate = log_mu.exp()
|
|
384
393
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
385
394
|
if self.loss_func == 'poisson':
|
|
386
395
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
387
396
|
|
|
388
397
|
if self.loss_func == 'negbinomial':
|
|
398
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
|
|
389
399
|
if self.use_zeroinflate:
|
|
390
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
|
|
400
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
|
|
401
|
+
logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
391
402
|
else:
|
|
392
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=
|
|
403
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
|
|
404
|
+
logits=logits).to_event(1), obs=xs)
|
|
393
405
|
elif self.loss_func == 'poisson':
|
|
394
406
|
if self.use_zeroinflate:
|
|
395
407
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
396
408
|
else:
|
|
397
409
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
398
410
|
elif self.loss_func == 'multinomial':
|
|
399
|
-
|
|
400
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
411
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
401
412
|
elif self.loss_func == 'bernoulli':
|
|
402
413
|
if self.use_zeroinflate:
|
|
403
414
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -421,8 +432,8 @@ class DensityFlow(nn.Module):
|
|
|
421
432
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
422
433
|
|
|
423
434
|
if self.loss_func=='negbinomial':
|
|
424
|
-
|
|
425
|
-
|
|
435
|
+
dispersion = pyro.param("dispersion", self.dispersion *
|
|
436
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
426
437
|
|
|
427
438
|
if self.use_zeroinflate:
|
|
428
439
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -461,28 +472,30 @@ class DensityFlow(nn.Module):
|
|
|
461
472
|
else:
|
|
462
473
|
zs = zns
|
|
463
474
|
|
|
464
|
-
|
|
475
|
+
log_mu = self.decoder_log_mu(zs)
|
|
465
476
|
if self.loss_func in ['bernoulli']:
|
|
466
|
-
log_theta =
|
|
477
|
+
log_theta = log_mu
|
|
478
|
+
elif self.loss_func == 'negbinomial':
|
|
479
|
+
mu = log_mu.exp()
|
|
467
480
|
else:
|
|
468
|
-
rate =
|
|
481
|
+
rate = log_mu.exp()
|
|
469
482
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
470
483
|
if self.loss_func == 'poisson':
|
|
471
484
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
472
485
|
|
|
473
486
|
if self.loss_func == 'negbinomial':
|
|
487
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
|
|
474
488
|
if self.use_zeroinflate:
|
|
475
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
|
|
489
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
476
490
|
else:
|
|
477
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=
|
|
491
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
|
|
478
492
|
elif self.loss_func == 'poisson':
|
|
479
493
|
if self.use_zeroinflate:
|
|
480
494
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
481
495
|
else:
|
|
482
496
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
483
497
|
elif self.loss_func == 'multinomial':
|
|
484
|
-
|
|
485
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
498
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
486
499
|
elif self.loss_func == 'bernoulli':
|
|
487
500
|
if self.use_zeroinflate:
|
|
488
501
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -506,8 +519,8 @@ class DensityFlow(nn.Module):
|
|
|
506
519
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
507
520
|
|
|
508
521
|
if self.loss_func=='negbinomial':
|
|
509
|
-
|
|
510
|
-
|
|
522
|
+
dispersion = pyro.param("dispersion", self.dispersion *
|
|
523
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
511
524
|
|
|
512
525
|
if self.use_zeroinflate:
|
|
513
526
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -558,28 +571,31 @@ class DensityFlow(nn.Module):
|
|
|
558
571
|
|
|
559
572
|
zs = zns
|
|
560
573
|
|
|
561
|
-
|
|
574
|
+
log_mu = self.decoder_log_mu(zs)
|
|
562
575
|
if self.loss_func in ['bernoulli']:
|
|
563
|
-
log_theta =
|
|
576
|
+
log_theta = log_mu
|
|
577
|
+
elif self.loss_func in ['negbinomial']:
|
|
578
|
+
mu = log_mu.exp()
|
|
564
579
|
else:
|
|
565
|
-
rate =
|
|
580
|
+
rate = log_mu.exp()
|
|
566
581
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
567
582
|
if self.loss_func == 'poisson':
|
|
568
583
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
569
584
|
|
|
570
585
|
if self.loss_func == 'negbinomial':
|
|
586
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
|
|
571
587
|
if self.use_zeroinflate:
|
|
572
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
|
|
588
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
|
|
589
|
+
logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
573
590
|
else:
|
|
574
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=
|
|
591
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
|
|
575
592
|
elif self.loss_func == 'poisson':
|
|
576
593
|
if self.use_zeroinflate:
|
|
577
594
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
578
595
|
else:
|
|
579
596
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
580
597
|
elif self.loss_func == 'multinomial':
|
|
581
|
-
|
|
582
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
598
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
583
599
|
elif self.loss_func == 'bernoulli':
|
|
584
600
|
if self.use_zeroinflate:
|
|
585
601
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -603,8 +619,8 @@ class DensityFlow(nn.Module):
|
|
|
603
619
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
604
620
|
|
|
605
621
|
if self.loss_func=='negbinomial':
|
|
606
|
-
|
|
607
|
-
|
|
622
|
+
dispersion = pyro.param("dispersion", self.dispersion *
|
|
623
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
608
624
|
|
|
609
625
|
if self.use_zeroinflate:
|
|
610
626
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -665,28 +681,31 @@ class DensityFlow(nn.Module):
|
|
|
665
681
|
else:
|
|
666
682
|
zs = zns
|
|
667
683
|
|
|
668
|
-
|
|
684
|
+
log_mu = self.decoder_log_mu(zs)
|
|
669
685
|
if self.loss_func in ['bernoulli']:
|
|
670
|
-
log_theta =
|
|
686
|
+
log_theta = log_mu
|
|
687
|
+
elif self.loss_func in ['negbinomial']:
|
|
688
|
+
mu = log_mu.exp()
|
|
671
689
|
else:
|
|
672
|
-
rate =
|
|
690
|
+
rate = log_mu.exp()
|
|
673
691
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
674
692
|
if self.loss_func == 'poisson':
|
|
675
693
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
676
694
|
|
|
677
695
|
if self.loss_func == 'negbinomial':
|
|
696
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
|
|
678
697
|
if self.use_zeroinflate:
|
|
679
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
|
|
698
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
|
|
699
|
+
logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
680
700
|
else:
|
|
681
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=
|
|
701
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
|
|
682
702
|
elif self.loss_func == 'poisson':
|
|
683
703
|
if self.use_zeroinflate:
|
|
684
704
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
685
705
|
else:
|
|
686
706
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
687
707
|
elif self.loss_func == 'multinomial':
|
|
688
|
-
|
|
689
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
708
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
690
709
|
elif self.loss_func == 'bernoulli':
|
|
691
710
|
if self.use_zeroinflate:
|
|
692
711
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -710,13 +729,13 @@ class DensityFlow(nn.Module):
|
|
|
710
729
|
# zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
711
730
|
#else:
|
|
712
731
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
713
|
-
zus = self.
|
|
732
|
+
zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
|
|
714
733
|
else:
|
|
715
734
|
#if self.turn_off_cell_specific:
|
|
716
735
|
# zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
717
736
|
#else:
|
|
718
737
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
719
|
-
zus = zus + self.
|
|
738
|
+
zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
|
|
720
739
|
return zus
|
|
721
740
|
|
|
722
741
|
def _get_codebook_identity(self):
|
|
@@ -858,12 +877,12 @@ class DensityFlow(nn.Module):
|
|
|
858
877
|
us_i = us[:,pert_idx].reshape(-1,1)
|
|
859
878
|
|
|
860
879
|
# factor effect of xs
|
|
861
|
-
dzs0 = self.
|
|
880
|
+
dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
|
|
862
881
|
|
|
863
882
|
# perturbation effect
|
|
864
883
|
ps = np.ones_like(us_i)
|
|
865
884
|
if np.sum(np.abs(ps-us_i))>=1:
|
|
866
|
-
dzs = self.
|
|
885
|
+
dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
|
|
867
886
|
zs = zs + dzs0 + dzs
|
|
868
887
|
else:
|
|
869
888
|
zs = zs + dzs0
|
|
@@ -877,10 +896,11 @@ class DensityFlow(nn.Module):
|
|
|
877
896
|
library_sizes = library_sizes.reshape(-1,1)
|
|
878
897
|
|
|
879
898
|
counts = self.get_counts(zs, library_sizes=library_sizes)
|
|
899
|
+
log_mu = self.get_log_mu(zs)
|
|
880
900
|
|
|
881
|
-
return counts,
|
|
901
|
+
return counts, log_mu
|
|
882
902
|
|
|
883
|
-
def
|
|
903
|
+
def _cell_shift(self, zs, perturb_idx, perturb):
|
|
884
904
|
#zns,_ = self.encoder_zn(xs)
|
|
885
905
|
#zns,_ = self._get_basal_embedding(xs)
|
|
886
906
|
zns = zs
|
|
@@ -897,7 +917,7 @@ class DensityFlow(nn.Module):
|
|
|
897
917
|
|
|
898
918
|
return ms
|
|
899
919
|
|
|
900
|
-
def
|
|
920
|
+
def get_cell_shift(self,
|
|
901
921
|
zs,
|
|
902
922
|
perturb_idx,
|
|
903
923
|
perturb_us,
|
|
@@ -915,46 +935,43 @@ class DensityFlow(nn.Module):
|
|
|
915
935
|
Z = []
|
|
916
936
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
917
937
|
for Z_batch, P_batch, _ in dataloader:
|
|
918
|
-
zns = self.
|
|
938
|
+
zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
|
|
919
939
|
Z.append(tensor_to_numpy(zns))
|
|
920
940
|
pbar.update(1)
|
|
921
941
|
|
|
922
942
|
Z = np.concatenate(Z)
|
|
923
943
|
return Z
|
|
924
944
|
|
|
925
|
-
def
|
|
926
|
-
return self.
|
|
945
|
+
def _log_mu(self, zs):
|
|
946
|
+
return self.decoder_log_mu(zs)
|
|
927
947
|
|
|
928
|
-
def
|
|
929
|
-
delta_zs,
|
|
930
|
-
batch_size: int = 1024):
|
|
948
|
+
def get_log_mu(self, zs, batch_size: int = 1024):
|
|
931
949
|
"""
|
|
932
950
|
Return cells' changes in the feature space induced by specific perturbation of a factor
|
|
933
951
|
|
|
934
952
|
"""
|
|
935
|
-
|
|
936
|
-
dataset = CustomDataset(
|
|
953
|
+
zs = convert_to_tensor(zs, device=self.get_device())
|
|
954
|
+
dataset = CustomDataset(zs)
|
|
937
955
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
938
956
|
|
|
939
957
|
R = []
|
|
940
958
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
941
|
-
for
|
|
942
|
-
r = self.
|
|
959
|
+
for Z_batch, _ in dataloader:
|
|
960
|
+
r = self._log_mu(Z_batch)
|
|
943
961
|
R.append(tensor_to_numpy(r))
|
|
944
962
|
pbar.update(1)
|
|
945
963
|
|
|
946
964
|
R = np.concatenate(R)
|
|
947
965
|
return R
|
|
948
966
|
|
|
949
|
-
def _count(self,
|
|
967
|
+
def _count(self, log_mu, library_size=None):
|
|
950
968
|
if self.loss_func == 'bernoulli':
|
|
951
|
-
|
|
952
|
-
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
969
|
+
counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
|
|
953
970
|
elif self.loss_func == 'multinomial':
|
|
954
|
-
theta = dist.Multinomial(total_count=int(1e8), logits=
|
|
971
|
+
theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
|
|
955
972
|
counts = theta * library_size
|
|
956
973
|
else:
|
|
957
|
-
rate =
|
|
974
|
+
rate = log_mu.exp()
|
|
958
975
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
959
976
|
counts = theta * library_size
|
|
960
977
|
return counts
|
|
@@ -976,8 +993,8 @@ class DensityFlow(nn.Module):
|
|
|
976
993
|
E = []
|
|
977
994
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
978
995
|
for Z_batch, L_batch, _ in dataloader:
|
|
979
|
-
|
|
980
|
-
counts = self._count(
|
|
996
|
+
log_mu = self._log_mu(Z_batch)
|
|
997
|
+
counts = self._count(log_mu, L_batch)
|
|
981
998
|
E.append(tensor_to_numpy(counts))
|
|
982
999
|
pbar.update(1)
|
|
983
1000
|
|
|
@@ -1350,7 +1367,7 @@ def main():
|
|
|
1350
1367
|
df = DensityFlow(
|
|
1351
1368
|
input_size=input_size,
|
|
1352
1369
|
cell_factor_size=cell_factor_size,
|
|
1353
|
-
|
|
1370
|
+
dispersion=args.dispersion,
|
|
1354
1371
|
z_dim=args.z_dim,
|
|
1355
1372
|
hidden_layers=args.hidden_layers,
|
|
1356
1373
|
hidden_layer_activation=args.hidden_layer_activation,
|