SURE-tools 2.2.0__py3-none-any.whl → 2.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools has been flagged as potentially problematic; see the registry's advisory page for this release for more details.
- SURE/DensityFlow.py +3 -3
- SURE/SURE.py +6 -6
- SURE/flow/flow_stats.py +12 -0
- {sure_tools-2.2.0.dist-info → sure_tools-2.2.2.dist-info}/METADATA +1 -1
- {sure_tools-2.2.0.dist-info → sure_tools-2.2.2.dist-info}/RECORD +9 -9
- {sure_tools-2.2.0.dist-info → sure_tools-2.2.2.dist-info}/WHEEL +0 -0
- {sure_tools-2.2.0.dist-info → sure_tools-2.2.2.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.2.0.dist-info → sure_tools-2.2.2.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.2.0.dist-info → sure_tools-2.2.2.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
|
@@ -64,8 +64,8 @@ class DensityFlow(nn.Module):
|
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
65
|
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
|
-
use_zeroinflate: bool =
|
|
68
|
-
hidden_layers: list = [
|
|
67
|
+
use_zeroinflate: bool = True,
|
|
68
|
+
hidden_layers: list = [500],
|
|
69
69
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
70
|
nn_dropout: float = 0.1,
|
|
71
71
|
post_layer_fct: list = ['layernorm'],
|
|
@@ -970,7 +970,7 @@ class DensityFlow(nn.Module):
|
|
|
970
970
|
us = None,
|
|
971
971
|
ys = None,
|
|
972
972
|
zs = None,
|
|
973
|
-
num_epochs: int =
|
|
973
|
+
num_epochs: int = 500,
|
|
974
974
|
learning_rate: float = 0.0001,
|
|
975
975
|
batch_size: int = 256,
|
|
976
976
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
SURE/SURE.py
CHANGED
|
@@ -99,17 +99,17 @@ class SURE(nn.Module):
|
|
|
99
99
|
cell_factor_size: int = 0,
|
|
100
100
|
supervised_mode: bool = False,
|
|
101
101
|
z_dim: int = 10,
|
|
102
|
-
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
|
|
103
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
102
|
+
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
103
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
|
|
104
104
|
inverse_dispersion: float = 10.0,
|
|
105
105
|
use_zeroinflate: bool = True,
|
|
106
|
-
hidden_layers: list = [
|
|
106
|
+
hidden_layers: list = [500],
|
|
107
107
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
108
108
|
nn_dropout: float = 0.1,
|
|
109
109
|
post_layer_fct: list = ['layernorm'],
|
|
110
110
|
post_act_fct: list = None,
|
|
111
111
|
config_enum: str = 'parallel',
|
|
112
|
-
use_cuda: bool =
|
|
112
|
+
use_cuda: bool = True,
|
|
113
113
|
seed: int = 42,
|
|
114
114
|
dtype = torch.float32, # type: ignore
|
|
115
115
|
):
|
|
@@ -817,7 +817,7 @@ class SURE(nn.Module):
|
|
|
817
817
|
us = None,
|
|
818
818
|
ys = None,
|
|
819
819
|
zs = None,
|
|
820
|
-
num_epochs: int =
|
|
820
|
+
num_epochs: int = 500,
|
|
821
821
|
learning_rate: float = 0.0001,
|
|
822
822
|
batch_size: int = 256,
|
|
823
823
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
|
@@ -826,7 +826,7 @@ class SURE(nn.Module):
|
|
|
826
826
|
decay_rate: float = 0.9,
|
|
827
827
|
config_enum: str = 'parallel',
|
|
828
828
|
threshold: int = 0,
|
|
829
|
-
use_jax: bool =
|
|
829
|
+
use_jax: bool = True):
|
|
830
830
|
"""
|
|
831
831
|
Train the SURE model.
|
|
832
832
|
|
SURE/flow/flow_stats.py
CHANGED
|
@@ -41,6 +41,18 @@ class VectorFieldEval:
|
|
|
41
41
|
divergence[np.isnan(divergence)] = 0
|
|
42
42
|
|
|
43
43
|
return divergence
|
|
44
|
+
|
|
45
|
+
def movement_stats(self,vectors):
|
|
46
|
+
return calculate_movement_stats(vectors)
|
|
47
|
+
|
|
48
|
+
def direction_stats(self, vectors):
|
|
49
|
+
return calculate_direction_stats(vectors)
|
|
50
|
+
|
|
51
|
+
def movement_energy(self, vectors, masses=None):
|
|
52
|
+
return calculate_movement_energy(vectors, masses)
|
|
53
|
+
|
|
54
|
+
def movement_divergence(self, positions, vectors):
|
|
55
|
+
return calculate_movement_divergence(positions, vectors)
|
|
44
56
|
|
|
45
57
|
|
|
46
58
|
def calculate_movement_stats(vectors):
|
|
{sure_tools-2.2.0.dist-info → sure_tools-2.2.2.dist-info}/RECORD
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
SURE/DensityFlow.py,sha256=
|
|
1
|
+
SURE/DensityFlow.py,sha256=ACSFkwzlyBqLH39SSOKKepaVFXldpX0Hbu50jd9jjXQ,54676
|
|
2
2
|
SURE/PerturbFlow.py,sha256=kVpEdVgW_AGvvv7d1KnnBydx_hGfpXUXmFg4t4dPlwY,54677
|
|
3
|
-
SURE/SURE.py,sha256=
|
|
3
|
+
SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
|
|
4
4
|
SURE/__init__.py,sha256=NVp22RCHrhSwHNMomABC-eftoCYvt7vV1XOzim-UZHE,293
|
|
5
5
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
6
6
|
SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
|
|
@@ -10,7 +10,7 @@ SURE/atac/utils.py,sha256=m4NYwpy9O5T1pXTzgCOCcmlwrC6GTi-cQ5sm2wZu2O8,4354
|
|
|
10
10
|
SURE/codebook/__init__.py,sha256=2T5gjp8JIaBayrXAnOJYSebQHsWprOs87difpR1OPNw,243
|
|
11
11
|
SURE/codebook/codebook.py,sha256=ZlN6gRX9Gj2D2u3P5KeOsbZri0MoMAiJo9lNeL-MK-I,17117
|
|
12
12
|
SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
|
|
13
|
-
SURE/flow/flow_stats.py,sha256=
|
|
13
|
+
SURE/flow/flow_stats.py,sha256=6SzNMT59WRFRP1nC6bvpBPF7BugWnkIS_DSlr4S-Ez0,11338
|
|
14
14
|
SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
|
|
15
15
|
SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
|
|
16
16
|
SURE/perturb/perturb.py,sha256=1iSsCePcwkA2CyM1nCdq_G8gogUNjhMH0BfhhvhpJQk,5037
|
|
@@ -18,9 +18,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
|
18
18
|
SURE/utils/custom_mlp.py,sha256=HuNb7f8-6RFjsvfEu1XOuNpLrHZkGYHgf8TpJfPSNO0,9382
|
|
19
19
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
20
20
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
21
|
-
sure_tools-2.2.
|
|
22
|
-
sure_tools-2.2.
|
|
23
|
-
sure_tools-2.2.
|
|
24
|
-
sure_tools-2.2.
|
|
25
|
-
sure_tools-2.2.
|
|
26
|
-
sure_tools-2.2.
|
|
21
|
+
sure_tools-2.2.2.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
22
|
+
sure_tools-2.2.2.dist-info/METADATA,sha256=ptLGAXYWNil97QNNpcCCasW3pWTNe6kMRYm-vTm6TBE,2677
|
|
23
|
+
sure_tools-2.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
24
|
+
sure_tools-2.2.2.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
25
|
+
sure_tools-2.2.2.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
26
|
+
sure_tools-2.2.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|