SURE-tools 2.4.22__py3-none-any.whl → 2.4.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/DensityFlow.py +78 -68
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1413 -0
- SURE/PerturbationAwareDecoder.py +162 -148
- SURE/__init__.py +3 -1
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.32.dist-info}/METADATA +1 -1
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.32.dist-info}/RECORD +11 -9
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.32.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.32.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.32.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.22.dist-info → sure_tools-2.4.32.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
|
@@ -57,16 +57,16 @@ def set_random_seed(seed):
|
|
|
57
57
|
class DensityFlow(nn.Module):
|
|
58
58
|
def __init__(self,
|
|
59
59
|
input_size: int,
|
|
60
|
-
codebook_size: int =
|
|
60
|
+
codebook_size: int = 100,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
62
|
turn_off_cell_specific: bool = False,
|
|
63
63
|
supervised_mode: bool = False,
|
|
64
|
-
z_dim: int =
|
|
64
|
+
z_dim: int = 50,
|
|
65
65
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
66
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
67
|
-
|
|
66
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
|
|
67
|
+
dispersion: float = 8.0,
|
|
68
68
|
use_zeroinflate: bool = False,
|
|
69
|
-
hidden_layers: list = [
|
|
69
|
+
hidden_layers: list = [1024],
|
|
70
70
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
71
71
|
nn_dropout: float = 0.1,
|
|
72
72
|
post_layer_fct: list = ['layernorm'],
|
|
@@ -81,7 +81,7 @@ class DensityFlow(nn.Module):
|
|
|
81
81
|
|
|
82
82
|
self.input_size = input_size
|
|
83
83
|
self.cell_factor_size = cell_factor_size
|
|
84
|
-
self.
|
|
84
|
+
self.dispersion = dispersion
|
|
85
85
|
self.latent_dim = z_dim
|
|
86
86
|
self.hidden_layers = hidden_layers
|
|
87
87
|
self.decoder_hidden_layers = hidden_layers[::-1]
|
|
@@ -258,7 +258,7 @@ class DensityFlow(nn.Module):
|
|
|
258
258
|
)
|
|
259
259
|
)
|
|
260
260
|
|
|
261
|
-
self.
|
|
261
|
+
self.decoder_log_mu = MLP(
|
|
262
262
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
263
263
|
activation=activate_fct,
|
|
264
264
|
output_activation=None,
|
|
@@ -348,8 +348,8 @@ class DensityFlow(nn.Module):
|
|
|
348
348
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
349
349
|
|
|
350
350
|
if self.loss_func=='negbinomial':
|
|
351
|
-
|
|
352
|
-
|
|
351
|
+
dispersion = pyro.param("dispersion", self.dispersion *
|
|
352
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
353
353
|
|
|
354
354
|
if self.use_zeroinflate:
|
|
355
355
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -383,28 +383,32 @@ class DensityFlow(nn.Module):
|
|
|
383
383
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
384
384
|
|
|
385
385
|
zs = zns
|
|
386
|
-
|
|
386
|
+
log_mu = self.decoder_log_mu(zs)
|
|
387
387
|
if self.loss_func in ['bernoulli']:
|
|
388
|
-
log_theta =
|
|
388
|
+
log_theta = log_mu
|
|
389
|
+
elif self.loss_func == 'negbinomial':
|
|
390
|
+
mu = log_mu.exp()
|
|
389
391
|
else:
|
|
390
|
-
rate =
|
|
392
|
+
rate = log_mu.exp()
|
|
391
393
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
392
394
|
if self.loss_func == 'poisson':
|
|
393
395
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
394
396
|
|
|
395
397
|
if self.loss_func == 'negbinomial':
|
|
398
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
|
|
396
399
|
if self.use_zeroinflate:
|
|
397
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
|
|
400
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
|
|
401
|
+
logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
398
402
|
else:
|
|
399
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=
|
|
403
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
|
|
404
|
+
logits=logits).to_event(1), obs=xs)
|
|
400
405
|
elif self.loss_func == 'poisson':
|
|
401
406
|
if self.use_zeroinflate:
|
|
402
407
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
403
408
|
else:
|
|
404
409
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
405
410
|
elif self.loss_func == 'multinomial':
|
|
406
|
-
|
|
407
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
411
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
408
412
|
elif self.loss_func == 'bernoulli':
|
|
409
413
|
if self.use_zeroinflate:
|
|
410
414
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -428,8 +432,8 @@ class DensityFlow(nn.Module):
|
|
|
428
432
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
429
433
|
|
|
430
434
|
if self.loss_func=='negbinomial':
|
|
431
|
-
|
|
432
|
-
|
|
435
|
+
dispersion = pyro.param("dispersion", self.dispersion *
|
|
436
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
433
437
|
|
|
434
438
|
if self.use_zeroinflate:
|
|
435
439
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -468,28 +472,30 @@ class DensityFlow(nn.Module):
|
|
|
468
472
|
else:
|
|
469
473
|
zs = zns
|
|
470
474
|
|
|
471
|
-
|
|
475
|
+
log_mu = self.decoder_log_mu(zs)
|
|
472
476
|
if self.loss_func in ['bernoulli']:
|
|
473
|
-
log_theta =
|
|
477
|
+
log_theta = log_mu
|
|
478
|
+
elif self.loss_func == 'negbinomial':
|
|
479
|
+
mu = log_mu.exp()
|
|
474
480
|
else:
|
|
475
|
-
rate =
|
|
481
|
+
rate = log_mu.exp()
|
|
476
482
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
477
483
|
if self.loss_func == 'poisson':
|
|
478
484
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
479
485
|
|
|
480
486
|
if self.loss_func == 'negbinomial':
|
|
487
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
|
|
481
488
|
if self.use_zeroinflate:
|
|
482
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
|
|
489
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
483
490
|
else:
|
|
484
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=
|
|
491
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
|
|
485
492
|
elif self.loss_func == 'poisson':
|
|
486
493
|
if self.use_zeroinflate:
|
|
487
494
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
488
495
|
else:
|
|
489
496
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
490
497
|
elif self.loss_func == 'multinomial':
|
|
491
|
-
|
|
492
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
498
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
493
499
|
elif self.loss_func == 'bernoulli':
|
|
494
500
|
if self.use_zeroinflate:
|
|
495
501
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -513,8 +519,8 @@ class DensityFlow(nn.Module):
|
|
|
513
519
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
514
520
|
|
|
515
521
|
if self.loss_func=='negbinomial':
|
|
516
|
-
|
|
517
|
-
|
|
522
|
+
dispersion = pyro.param("dispersion", self.dispersion *
|
|
523
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
518
524
|
|
|
519
525
|
if self.use_zeroinflate:
|
|
520
526
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -565,28 +571,31 @@ class DensityFlow(nn.Module):
|
|
|
565
571
|
|
|
566
572
|
zs = zns
|
|
567
573
|
|
|
568
|
-
|
|
574
|
+
log_mu = self.decoder_log_mu(zs)
|
|
569
575
|
if self.loss_func in ['bernoulli']:
|
|
570
|
-
log_theta =
|
|
576
|
+
log_theta = log_mu
|
|
577
|
+
elif self.loss_func in ['negbinomial']:
|
|
578
|
+
mu = log_mu.exp()
|
|
571
579
|
else:
|
|
572
|
-
rate =
|
|
580
|
+
rate = log_mu.exp()
|
|
573
581
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
574
582
|
if self.loss_func == 'poisson':
|
|
575
583
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
576
584
|
|
|
577
585
|
if self.loss_func == 'negbinomial':
|
|
586
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
|
|
578
587
|
if self.use_zeroinflate:
|
|
579
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
|
|
588
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
|
|
589
|
+
logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
580
590
|
else:
|
|
581
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=
|
|
591
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
|
|
582
592
|
elif self.loss_func == 'poisson':
|
|
583
593
|
if self.use_zeroinflate:
|
|
584
594
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
585
595
|
else:
|
|
586
596
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
587
597
|
elif self.loss_func == 'multinomial':
|
|
588
|
-
|
|
589
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
598
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
590
599
|
elif self.loss_func == 'bernoulli':
|
|
591
600
|
if self.use_zeroinflate:
|
|
592
601
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -610,8 +619,8 @@ class DensityFlow(nn.Module):
|
|
|
610
619
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
611
620
|
|
|
612
621
|
if self.loss_func=='negbinomial':
|
|
613
|
-
|
|
614
|
-
|
|
622
|
+
dispersion = pyro.param("dispersion", self.dispersion *
|
|
623
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
615
624
|
|
|
616
625
|
if self.use_zeroinflate:
|
|
617
626
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -672,28 +681,31 @@ class DensityFlow(nn.Module):
|
|
|
672
681
|
else:
|
|
673
682
|
zs = zns
|
|
674
683
|
|
|
675
|
-
|
|
684
|
+
log_mu = self.decoder_log_mu(zs)
|
|
676
685
|
if self.loss_func in ['bernoulli']:
|
|
677
|
-
log_theta =
|
|
686
|
+
log_theta = log_mu
|
|
687
|
+
elif self.loss_func in ['negbinomial']:
|
|
688
|
+
mu = log_mu.exp()
|
|
678
689
|
else:
|
|
679
|
-
rate =
|
|
690
|
+
rate = log_mu.exp()
|
|
680
691
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
681
692
|
if self.loss_func == 'poisson':
|
|
682
693
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
683
694
|
|
|
684
695
|
if self.loss_func == 'negbinomial':
|
|
696
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
|
|
685
697
|
if self.use_zeroinflate:
|
|
686
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
|
|
698
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
|
|
699
|
+
logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
687
700
|
else:
|
|
688
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=
|
|
701
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
|
|
689
702
|
elif self.loss_func == 'poisson':
|
|
690
703
|
if self.use_zeroinflate:
|
|
691
704
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
692
705
|
else:
|
|
693
706
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
694
707
|
elif self.loss_func == 'multinomial':
|
|
695
|
-
|
|
696
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
708
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
697
709
|
elif self.loss_func == 'bernoulli':
|
|
698
710
|
if self.use_zeroinflate:
|
|
699
711
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -717,13 +729,13 @@ class DensityFlow(nn.Module):
|
|
|
717
729
|
# zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
718
730
|
#else:
|
|
719
731
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
720
|
-
zus = self.
|
|
732
|
+
zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
|
|
721
733
|
else:
|
|
722
734
|
#if self.turn_off_cell_specific:
|
|
723
735
|
# zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
724
736
|
#else:
|
|
725
737
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
726
|
-
zus = zus + self.
|
|
738
|
+
zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
|
|
727
739
|
return zus
|
|
728
740
|
|
|
729
741
|
def _get_codebook_identity(self):
|
|
@@ -865,12 +877,12 @@ class DensityFlow(nn.Module):
|
|
|
865
877
|
us_i = us[:,pert_idx].reshape(-1,1)
|
|
866
878
|
|
|
867
879
|
# factor effect of xs
|
|
868
|
-
dzs0 = self.
|
|
880
|
+
dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
|
|
869
881
|
|
|
870
882
|
# perturbation effect
|
|
871
883
|
ps = np.ones_like(us_i)
|
|
872
884
|
if np.sum(np.abs(ps-us_i))>=1:
|
|
873
|
-
dzs = self.
|
|
885
|
+
dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
|
|
874
886
|
zs = zs + dzs0 + dzs
|
|
875
887
|
else:
|
|
876
888
|
zs = zs + dzs0
|
|
@@ -884,10 +896,11 @@ class DensityFlow(nn.Module):
|
|
|
884
896
|
library_sizes = library_sizes.reshape(-1,1)
|
|
885
897
|
|
|
886
898
|
counts = self.get_counts(zs, library_sizes=library_sizes)
|
|
899
|
+
log_mu = self.get_log_mu(zs)
|
|
887
900
|
|
|
888
|
-
return counts,
|
|
901
|
+
return counts, log_mu
|
|
889
902
|
|
|
890
|
-
def
|
|
903
|
+
def _cell_shift(self, zs, perturb_idx, perturb):
|
|
891
904
|
#zns,_ = self.encoder_zn(xs)
|
|
892
905
|
#zns,_ = self._get_basal_embedding(xs)
|
|
893
906
|
zns = zs
|
|
@@ -904,7 +917,7 @@ class DensityFlow(nn.Module):
|
|
|
904
917
|
|
|
905
918
|
return ms
|
|
906
919
|
|
|
907
|
-
def
|
|
920
|
+
def get_cell_shift(self,
|
|
908
921
|
zs,
|
|
909
922
|
perturb_idx,
|
|
910
923
|
perturb_us,
|
|
@@ -922,46 +935,43 @@ class DensityFlow(nn.Module):
|
|
|
922
935
|
Z = []
|
|
923
936
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
924
937
|
for Z_batch, P_batch, _ in dataloader:
|
|
925
|
-
zns = self.
|
|
938
|
+
zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
|
|
926
939
|
Z.append(tensor_to_numpy(zns))
|
|
927
940
|
pbar.update(1)
|
|
928
941
|
|
|
929
942
|
Z = np.concatenate(Z)
|
|
930
943
|
return Z
|
|
931
944
|
|
|
932
|
-
def
|
|
933
|
-
return self.
|
|
945
|
+
def _log_mu(self, zs):
|
|
946
|
+
return self.decoder_log_mu(zs)
|
|
934
947
|
|
|
935
|
-
def
|
|
936
|
-
delta_zs,
|
|
937
|
-
batch_size: int = 1024):
|
|
948
|
+
def get_log_mu(self, zs, batch_size: int = 1024):
|
|
938
949
|
"""
|
|
939
950
|
Return cells' changes in the feature space induced by specific perturbation of a factor
|
|
940
951
|
|
|
941
952
|
"""
|
|
942
|
-
|
|
943
|
-
dataset = CustomDataset(
|
|
953
|
+
zs = convert_to_tensor(zs, device=self.get_device())
|
|
954
|
+
dataset = CustomDataset(zs)
|
|
944
955
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
945
956
|
|
|
946
957
|
R = []
|
|
947
958
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
948
|
-
for
|
|
949
|
-
r = self.
|
|
959
|
+
for Z_batch, _ in dataloader:
|
|
960
|
+
r = self._log_mu(Z_batch)
|
|
950
961
|
R.append(tensor_to_numpy(r))
|
|
951
962
|
pbar.update(1)
|
|
952
963
|
|
|
953
964
|
R = np.concatenate(R)
|
|
954
965
|
return R
|
|
955
966
|
|
|
956
|
-
def _count(self,
|
|
967
|
+
def _count(self, log_mu, library_size=None):
|
|
957
968
|
if self.loss_func == 'bernoulli':
|
|
958
|
-
|
|
959
|
-
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
969
|
+
counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
|
|
960
970
|
elif self.loss_func == 'multinomial':
|
|
961
|
-
theta = dist.Multinomial(total_count=int(1e8), logits=
|
|
971
|
+
theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
|
|
962
972
|
counts = theta * library_size
|
|
963
973
|
else:
|
|
964
|
-
rate =
|
|
974
|
+
rate = log_mu.exp()
|
|
965
975
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
966
976
|
counts = theta * library_size
|
|
967
977
|
return counts
|
|
@@ -983,8 +993,8 @@ class DensityFlow(nn.Module):
|
|
|
983
993
|
E = []
|
|
984
994
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
985
995
|
for Z_batch, L_batch, _ in dataloader:
|
|
986
|
-
|
|
987
|
-
counts = self._count(
|
|
996
|
+
log_mu = self._log_mu(Z_batch)
|
|
997
|
+
counts = self._count(log_mu, L_batch)
|
|
988
998
|
E.append(tensor_to_numpy(counts))
|
|
989
999
|
pbar.update(1)
|
|
990
1000
|
|
|
@@ -1357,7 +1367,7 @@ def main():
|
|
|
1357
1367
|
df = DensityFlow(
|
|
1358
1368
|
input_size=input_size,
|
|
1359
1369
|
cell_factor_size=cell_factor_size,
|
|
1360
|
-
|
|
1370
|
+
dispersion=args.dispersion,
|
|
1361
1371
|
z_dim=args.z_dim,
|
|
1362
1372
|
hidden_layers=args.hidden_layers,
|
|
1363
1373
|
hidden_layer_activation=args.hidden_layer_activation,
|