SURE-tools 2.1.55__tar.gz → 2.1.57__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.1.55 → sure_tools-2.1.57}/PKG-INFO +1 -1
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/PerturbFlow.py +33 -25
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.55 → sure_tools-2.1.57}/setup.py +1 -1
- {sure_tools-2.1.55 → sure_tools-2.1.57}/LICENSE +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/README.md +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/SURE.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/__init__.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.55 → sure_tools-2.1.57}/setup.cfg +0 -0
|
@@ -244,7 +244,7 @@ class PerturbFlow(nn.Module):
|
|
|
244
244
|
# allow_broadcast=self.allow_broadcast,
|
|
245
245
|
# use_cuda=self.use_cuda,
|
|
246
246
|
# )
|
|
247
|
-
self.encoder_concentrate = self.decoder_concentrate
|
|
247
|
+
#self.encoder_concentrate = self.decoder_concentrate
|
|
248
248
|
else:
|
|
249
249
|
self.decoder_concentrate = MLP(
|
|
250
250
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -380,6 +380,8 @@ class PerturbFlow(nn.Module):
|
|
|
380
380
|
concentrate = self.decoder_concentrate(zs)
|
|
381
381
|
if self.loss_func == 'bernoulli':
|
|
382
382
|
log_theta = concentrate
|
|
383
|
+
elif self.loss_func == 'negbinomial':
|
|
384
|
+
log_theta = concentrate
|
|
383
385
|
else:
|
|
384
386
|
rate = concentrate.exp()
|
|
385
387
|
if self.loss_func != 'poisson':
|
|
@@ -387,9 +389,9 @@ class PerturbFlow(nn.Module):
|
|
|
387
389
|
|
|
388
390
|
if self.loss_func == 'negbinomial':
|
|
389
391
|
if self.use_zeroinflate:
|
|
390
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
392
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
391
393
|
else:
|
|
392
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
394
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
393
395
|
elif self.loss_func == 'poisson':
|
|
394
396
|
if self.use_zeroinflate:
|
|
395
397
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -417,9 +419,9 @@ class PerturbFlow(nn.Module):
|
|
|
417
419
|
alpha = self.encoder_n(zns)
|
|
418
420
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
419
421
|
|
|
420
|
-
if self.loss_func == 'gamma-poisson':
|
|
421
|
-
|
|
422
|
-
|
|
422
|
+
#if self.loss_func == 'gamma-poisson':
|
|
423
|
+
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
424
|
+
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
423
425
|
|
|
424
426
|
def model2(self, xs, us=None):
|
|
425
427
|
pyro.module('PerturbFlow', self)
|
|
@@ -484,6 +486,8 @@ class PerturbFlow(nn.Module):
|
|
|
484
486
|
concentrate = self.decoder_concentrate(zs)
|
|
485
487
|
if self.loss_func == 'bernoulli':
|
|
486
488
|
log_theta = concentrate
|
|
489
|
+
elif self.loss_func == 'negbinomial':
|
|
490
|
+
log_theta = concentrate
|
|
487
491
|
else:
|
|
488
492
|
rate = concentrate.exp()
|
|
489
493
|
if self.loss_func != 'poisson':
|
|
@@ -491,9 +495,9 @@ class PerturbFlow(nn.Module):
|
|
|
491
495
|
|
|
492
496
|
if self.loss_func == 'negbinomial':
|
|
493
497
|
if self.use_zeroinflate:
|
|
494
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
498
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
495
499
|
else:
|
|
496
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
500
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
497
501
|
elif self.loss_func == 'poisson':
|
|
498
502
|
if self.use_zeroinflate:
|
|
499
503
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -521,9 +525,9 @@ class PerturbFlow(nn.Module):
|
|
|
521
525
|
alpha = self.encoder_n(zns)
|
|
522
526
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
523
527
|
|
|
524
|
-
if self.loss_func == 'gamma-poisson':
|
|
525
|
-
|
|
526
|
-
|
|
528
|
+
#if self.loss_func == 'gamma-poisson':
|
|
529
|
+
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
530
|
+
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
527
531
|
|
|
528
532
|
def model3(self, xs, ys, embeds=None):
|
|
529
533
|
pyro.module('PerturbFlow', self)
|
|
@@ -594,6 +598,8 @@ class PerturbFlow(nn.Module):
|
|
|
594
598
|
concentrate = self.decoder_concentrate(zs)
|
|
595
599
|
if self.loss_func == 'bernoulli':
|
|
596
600
|
log_theta = concentrate
|
|
601
|
+
elif self.loss_func == 'negbinomial':
|
|
602
|
+
log_theta = concentrate
|
|
597
603
|
else:
|
|
598
604
|
rate = concentrate.exp()
|
|
599
605
|
if self.loss_func != 'poisson':
|
|
@@ -601,9 +607,9 @@ class PerturbFlow(nn.Module):
|
|
|
601
607
|
|
|
602
608
|
if self.loss_func == 'negbinomial':
|
|
603
609
|
if self.use_zeroinflate:
|
|
604
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
610
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
605
611
|
else:
|
|
606
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
612
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
607
613
|
elif self.loss_func == 'poisson':
|
|
608
614
|
if self.use_zeroinflate:
|
|
609
615
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -631,9 +637,9 @@ class PerturbFlow(nn.Module):
|
|
|
631
637
|
else:
|
|
632
638
|
zns = embeds
|
|
633
639
|
|
|
634
|
-
if self.loss_func == 'gamma-poisson':
|
|
635
|
-
|
|
636
|
-
|
|
640
|
+
#if self.loss_func == 'gamma-poisson':
|
|
641
|
+
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
642
|
+
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
637
643
|
|
|
638
644
|
def model4(self, xs, us, ys, embeds=None):
|
|
639
645
|
pyro.module('PerturbFlow', self)
|
|
@@ -714,6 +720,8 @@ class PerturbFlow(nn.Module):
|
|
|
714
720
|
concentrate = self.decoder_concentrate(zs)
|
|
715
721
|
if self.loss_func == 'bernoulli':
|
|
716
722
|
log_theta = concentrate
|
|
723
|
+
elif self.loss_func == 'negbinomial':
|
|
724
|
+
log_theta = concentrate
|
|
717
725
|
else:
|
|
718
726
|
rate = concentrate.exp()
|
|
719
727
|
if self.loss_func != 'poisson':
|
|
@@ -721,9 +729,9 @@ class PerturbFlow(nn.Module):
|
|
|
721
729
|
|
|
722
730
|
if self.loss_func == 'negbinomial':
|
|
723
731
|
if self.use_zeroinflate:
|
|
724
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
732
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
725
733
|
else:
|
|
726
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
734
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
727
735
|
elif self.loss_func == 'poisson':
|
|
728
736
|
if self.use_zeroinflate:
|
|
729
737
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -751,9 +759,9 @@ class PerturbFlow(nn.Module):
|
|
|
751
759
|
else:
|
|
752
760
|
zns = embeds
|
|
753
761
|
|
|
754
|
-
if self.loss_func == 'gamma-poisson':
|
|
755
|
-
|
|
756
|
-
|
|
762
|
+
#if self.loss_func == 'gamma-poisson':
|
|
763
|
+
# con_alpha,con_beta = self.encoder_concentrate(zns)
|
|
764
|
+
# rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
|
|
757
765
|
|
|
758
766
|
def _total_effects(self, zns, us):
|
|
759
767
|
zus = None
|
|
@@ -933,7 +941,7 @@ class PerturbFlow(nn.Module):
|
|
|
933
941
|
|
|
934
942
|
def _get_expression_response(self, delta_zs):
|
|
935
943
|
if self.loss_func == 'gamma-poisson':
|
|
936
|
-
alpha,beta = self.
|
|
944
|
+
alpha,beta = self.decoder_concentrate(delta_zs)
|
|
937
945
|
xs = dist.Gamma(alpha,beta).to_event(1).mean
|
|
938
946
|
else:
|
|
939
947
|
xs = self.decoder_concentrate(delta_zs)
|
|
@@ -966,11 +974,11 @@ class PerturbFlow(nn.Module):
|
|
|
966
974
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
967
975
|
elif self.loss_func == 'negbinomial':
|
|
968
976
|
#counts = concentrate.exp()
|
|
969
|
-
rate = concentrate.exp()
|
|
970
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
977
|
+
#rate = concentrate.exp()
|
|
978
|
+
#theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
971
979
|
|
|
972
980
|
total_count = pyro.param("inverse_dispersion")
|
|
973
|
-
counts = dist.NegativeBinomial(total_count=total_count,
|
|
981
|
+
counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean
|
|
974
982
|
elif self.loss_func == 'poisson':
|
|
975
983
|
rate = concentrate.exp()
|
|
976
984
|
counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|