SURE-tools 2.1.50__py3-none-any.whl → 2.1.52__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/PerturbFlow.py +14 -5
- {sure_tools-2.1.50.dist-info → sure_tools-2.1.52.dist-info}/METADATA +1 -1
- {sure_tools-2.1.50.dist-info → sure_tools-2.1.52.dist-info}/RECORD +7 -7
- {sure_tools-2.1.50.dist-info → sure_tools-2.1.52.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.50.dist-info → sure_tools-2.1.52.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.50.dist-info → sure_tools-2.1.52.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.50.dist-info → sure_tools-2.1.52.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -64,7 +64,7 @@ class PerturbFlow(nn.Module):
|
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
65
65
|
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
|
-
use_zeroinflate: bool =
|
|
67
|
+
use_zeroinflate: bool = False,
|
|
68
68
|
hidden_layers: list = [300],
|
|
69
69
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
70
|
nn_dropout: float = 0.1,
|
|
@@ -872,9 +872,18 @@ class PerturbFlow(nn.Module):
|
|
|
872
872
|
|
|
873
873
|
def _count(self,concentrate):
|
|
874
874
|
if self.loss_func == 'bernoulli':
|
|
875
|
-
counts = self.sigmoid(concentrate)
|
|
876
|
-
|
|
877
|
-
|
|
875
|
+
#counts = self.sigmoid(concentrate)
|
|
876
|
+
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
877
|
+
elif self.loss_func == 'negbinomial':
|
|
878
|
+
#counts = concentrate.exp()
|
|
879
|
+
rate = concentrate.exp()
|
|
880
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
881
|
+
|
|
882
|
+
total_count = pyro.param("inverse_dispersion")
|
|
883
|
+
counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1)
|
|
884
|
+
elif self.loss_func == 'poisson':
|
|
885
|
+
rate = concentrate.exp()
|
|
886
|
+
counts = dist.Poisson(rate=rate).to_event(1)
|
|
878
887
|
return counts
|
|
879
888
|
|
|
880
889
|
def _count_sample(self,concentrate):
|
|
@@ -933,7 +942,7 @@ class PerturbFlow(nn.Module):
|
|
|
933
942
|
decay_rate: float = 0.9,
|
|
934
943
|
config_enum: str = 'parallel',
|
|
935
944
|
threshold: int = 0,
|
|
936
|
-
use_jax: bool =
|
|
945
|
+
use_jax: bool = True):
|
|
937
946
|
"""
|
|
938
947
|
Train the PerturbFlow model.
|
|
939
948
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=yt0kOW4buZKpJQ3Jn_8Zd2uEKUq29DZYkSzcgEP55EA,53211
|
|
2
2
|
SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
|
17
17
|
SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.52.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.52.dist-info/METADATA,sha256=ARO36IQ9aKV9Sp4F9AkLR6zzXSsPnMGguljP7XU95Mk,2678
|
|
22
|
+
sure_tools-2.1.52.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.52.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.52.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.52.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|