SURE-tools 2.1.56__py3-none-any.whl → 2.1.57__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools has been flagged as potentially problematic; see the registry's advisory for this release for more details.
- SURE/PerturbFlow.py +19 -11
- {sure_tools-2.1.56.dist-info → sure_tools-2.1.57.dist-info}/METADATA +1 -1
- {sure_tools-2.1.56.dist-info → sure_tools-2.1.57.dist-info}/RECORD +7 -7
- {sure_tools-2.1.56.dist-info → sure_tools-2.1.57.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.56.dist-info → sure_tools-2.1.57.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.56.dist-info → sure_tools-2.1.57.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.56.dist-info → sure_tools-2.1.57.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -380,6 +380,8 @@ class PerturbFlow(nn.Module):
|
|
|
380
380
|
concentrate = self.decoder_concentrate(zs)
|
|
381
381
|
if self.loss_func == 'bernoulli':
|
|
382
382
|
log_theta = concentrate
|
|
383
|
+
elif self.loss_func == 'negbinomial':
|
|
384
|
+
log_theta = concentrate
|
|
383
385
|
else:
|
|
384
386
|
rate = concentrate.exp()
|
|
385
387
|
if self.loss_func != 'poisson':
|
|
@@ -387,9 +389,9 @@ class PerturbFlow(nn.Module):
|
|
|
387
389
|
|
|
388
390
|
if self.loss_func == 'negbinomial':
|
|
389
391
|
if self.use_zeroinflate:
|
|
390
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
392
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
391
393
|
else:
|
|
392
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
394
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
393
395
|
elif self.loss_func == 'poisson':
|
|
394
396
|
if self.use_zeroinflate:
|
|
395
397
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -484,6 +486,8 @@ class PerturbFlow(nn.Module):
|
|
|
484
486
|
concentrate = self.decoder_concentrate(zs)
|
|
485
487
|
if self.loss_func == 'bernoulli':
|
|
486
488
|
log_theta = concentrate
|
|
489
|
+
elif self.loss_func == 'negbinomial':
|
|
490
|
+
log_theta = concentrate
|
|
487
491
|
else:
|
|
488
492
|
rate = concentrate.exp()
|
|
489
493
|
if self.loss_func != 'poisson':
|
|
@@ -491,9 +495,9 @@ class PerturbFlow(nn.Module):
|
|
|
491
495
|
|
|
492
496
|
if self.loss_func == 'negbinomial':
|
|
493
497
|
if self.use_zeroinflate:
|
|
494
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
498
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
495
499
|
else:
|
|
496
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
500
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
497
501
|
elif self.loss_func == 'poisson':
|
|
498
502
|
if self.use_zeroinflate:
|
|
499
503
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -594,6 +598,8 @@ class PerturbFlow(nn.Module):
|
|
|
594
598
|
concentrate = self.decoder_concentrate(zs)
|
|
595
599
|
if self.loss_func == 'bernoulli':
|
|
596
600
|
log_theta = concentrate
|
|
601
|
+
elif self.loss_func == 'negbinomial':
|
|
602
|
+
log_theta = concentrate
|
|
597
603
|
else:
|
|
598
604
|
rate = concentrate.exp()
|
|
599
605
|
if self.loss_func != 'poisson':
|
|
@@ -601,9 +607,9 @@ class PerturbFlow(nn.Module):
|
|
|
601
607
|
|
|
602
608
|
if self.loss_func == 'negbinomial':
|
|
603
609
|
if self.use_zeroinflate:
|
|
604
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
610
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
605
611
|
else:
|
|
606
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
612
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
607
613
|
elif self.loss_func == 'poisson':
|
|
608
614
|
if self.use_zeroinflate:
|
|
609
615
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -714,6 +720,8 @@ class PerturbFlow(nn.Module):
|
|
|
714
720
|
concentrate = self.decoder_concentrate(zs)
|
|
715
721
|
if self.loss_func == 'bernoulli':
|
|
716
722
|
log_theta = concentrate
|
|
723
|
+
elif self.loss_func == 'negbinomial':
|
|
724
|
+
log_theta = concentrate
|
|
717
725
|
else:
|
|
718
726
|
rate = concentrate.exp()
|
|
719
727
|
if self.loss_func != 'poisson':
|
|
@@ -721,9 +729,9 @@ class PerturbFlow(nn.Module):
|
|
|
721
729
|
|
|
722
730
|
if self.loss_func == 'negbinomial':
|
|
723
731
|
if self.use_zeroinflate:
|
|
724
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
732
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
725
733
|
else:
|
|
726
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
734
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
|
|
727
735
|
elif self.loss_func == 'poisson':
|
|
728
736
|
if self.use_zeroinflate:
|
|
729
737
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -966,11 +974,11 @@ class PerturbFlow(nn.Module):
|
|
|
966
974
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
967
975
|
elif self.loss_func == 'negbinomial':
|
|
968
976
|
#counts = concentrate.exp()
|
|
969
|
-
rate = concentrate.exp()
|
|
970
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
977
|
+
#rate = concentrate.exp()
|
|
978
|
+
#theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
971
979
|
|
|
972
980
|
total_count = pyro.param("inverse_dispersion")
|
|
973
|
-
counts = dist.NegativeBinomial(total_count=total_count,
|
|
981
|
+
counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean
|
|
974
982
|
elif self.loss_func == 'poisson':
|
|
975
983
|
rate = concentrate.exp()
|
|
976
984
|
counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=emkyhDc99eTJQNkMdsHCp6VPg6468CRkc8lRHyA4P4o,59977
|
|
2
2
|
SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
|
17
17
|
SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.57.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.57.dist-info/METADATA,sha256=Y1npoz3fb9597vOLCGFK4__9N85QgzBnX_zUra5E1Fg,2678
|
|
22
|
+
sure_tools-2.1.57.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.57.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.57.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.57.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|