SURE-tools 2.1.51__py3-none-any.whl → 2.1.53__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/PerturbFlow.py +25 -8
- {sure_tools-2.1.51.dist-info → sure_tools-2.1.53.dist-info}/METADATA +1 -1
- {sure_tools-2.1.51.dist-info → sure_tools-2.1.53.dist-info}/RECORD +7 -7
- {sure_tools-2.1.51.dist-info → sure_tools-2.1.53.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.51.dist-info → sure_tools-2.1.53.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.51.dist-info → sure_tools-2.1.53.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.51.dist-info → sure_tools-2.1.53.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -317,12 +317,14 @@ class PerturbFlow(nn.Module):
|
|
|
317
317
|
if self.loss_func=='negbinomial':
|
|
318
318
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
319
319
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
320
|
+
elif self.loss_func == 'multinomial':
|
|
321
|
+
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
320
322
|
|
|
321
323
|
if self.use_zeroinflate:
|
|
322
324
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
323
325
|
|
|
324
326
|
acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
|
|
325
|
-
|
|
327
|
+
|
|
326
328
|
I = torch.eye(self.code_size)
|
|
327
329
|
if self.latent_dist=='studentt':
|
|
328
330
|
acs_dof,acs_loc = self.codebook(I)
|
|
@@ -369,7 +371,7 @@ class PerturbFlow(nn.Module):
|
|
|
369
371
|
else:
|
|
370
372
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
371
373
|
elif self.loss_func == 'multinomial':
|
|
372
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
374
|
+
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
373
375
|
elif self.loss_func == 'bernoulli':
|
|
374
376
|
if self.use_zeroinflate:
|
|
375
377
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -395,6 +397,8 @@ class PerturbFlow(nn.Module):
|
|
|
395
397
|
if self.loss_func=='negbinomial':
|
|
396
398
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
397
399
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
400
|
+
elif self.loss_func == 'multinomial':
|
|
401
|
+
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
398
402
|
|
|
399
403
|
if self.use_zeroinflate:
|
|
400
404
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -458,7 +462,7 @@ class PerturbFlow(nn.Module):
|
|
|
458
462
|
else:
|
|
459
463
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
460
464
|
elif self.loss_func == 'multinomial':
|
|
461
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
465
|
+
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
462
466
|
elif self.loss_func == 'bernoulli':
|
|
463
467
|
if self.use_zeroinflate:
|
|
464
468
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -484,6 +488,8 @@ class PerturbFlow(nn.Module):
|
|
|
484
488
|
if self.loss_func=='negbinomial':
|
|
485
489
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
486
490
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
491
|
+
elif self.loss_func == 'multinomial':
|
|
492
|
+
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
487
493
|
|
|
488
494
|
if self.use_zeroinflate:
|
|
489
495
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -553,7 +559,7 @@ class PerturbFlow(nn.Module):
|
|
|
553
559
|
else:
|
|
554
560
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
555
561
|
elif self.loss_func == 'multinomial':
|
|
556
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
562
|
+
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
557
563
|
elif self.loss_func == 'bernoulli':
|
|
558
564
|
if self.use_zeroinflate:
|
|
559
565
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -577,6 +583,8 @@ class PerturbFlow(nn.Module):
|
|
|
577
583
|
if self.loss_func=='negbinomial':
|
|
578
584
|
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
579
585
|
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
586
|
+
elif self.loss_func == 'multinomial':
|
|
587
|
+
total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
|
|
580
588
|
|
|
581
589
|
if self.use_zeroinflate:
|
|
582
590
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -656,7 +664,7 @@ class PerturbFlow(nn.Module):
|
|
|
656
664
|
else:
|
|
657
665
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
658
666
|
elif self.loss_func == 'multinomial':
|
|
659
|
-
pyro.sample('x', dist.Multinomial(total_count=
|
|
667
|
+
pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
|
|
660
668
|
elif self.loss_func == 'bernoulli':
|
|
661
669
|
if self.use_zeroinflate:
|
|
662
670
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -872,9 +880,18 @@ class PerturbFlow(nn.Module):
|
|
|
872
880
|
|
|
873
881
|
def _count(self, concentrate):
    """Convert decoder output into expected counts for ``self.loss_func``.

    Args:
        concentrate: Unnormalized decoder output (logits for 'bernoulli',
            log-rates otherwise). Presumably a (batch, input_size) tensor —
            TODO confirm against callers.

    Returns:
        A tensor holding the mean of the configured likelihood. Every branch
        returns a tensor, never a ``Distribution`` object, so callers can
        treat the result uniformly.
    """
    if self.loss_func == 'bernoulli':
        # Mean of Bernoulli(logits=c) is sigmoid(c).
        counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
    elif self.loss_func == 'negbinomial':
        rate = concentrate.exp()
        # DirichletMultinomial(total_count=1).mean rescales ``rate`` so it sums
        # to one over the last dimension.
        theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean

        # Use the same initializer/constraint as the model so this does not
        # fail when the param store is still empty; an existing stored value
        # is returned unchanged by pyro.param.
        total_count = pyro.param("inverse_dispersion",
                                 self.inverse_dispersion * concentrate.new_ones(self.input_size),
                                 constraint=constraints.positive)
        # Take the mean: returning the Distribution itself was inconsistent
        # with the tensor returned by the other branches.
        counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
    elif self.loss_func == 'poisson':
        rate = concentrate.exp()
        counts = dist.Poisson(rate=rate).to_event(1).mean
    else:
        # Any other likelihood (e.g. 'multinomial') previously left ``counts``
        # unbound and raised UnboundLocalError; fall back to exp(concentrate).
        counts = concentrate.exp()
    return counts
|
|
879
896
|
|
|
880
897
|
def _count_sample(self,concentrate):
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=JS0TguFFewNU6lwFLI0rtJsPUkDcHWFpN2USuBB1dL8,53827
|
|
2
2
|
SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
|
17
17
|
SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.53.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.53.dist-info/METADATA,sha256=wNhmVGxxzIeL38Nb2VXIAJC6zX_jK3SgFqnqCd56ajA,2678
|
|
22
|
+
sure_tools-2.1.53.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.53.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.53.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.53.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|