SURE-tools 2.2.7__py3-none-any.whl → 2.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/DensityFlow.py +82 -64
- SURE/perturb/perturb.py +27 -1
- {sure_tools-2.2.7.dist-info → sure_tools-2.2.18.dist-info}/METADATA +1 -1
- {sure_tools-2.2.7.dist-info → sure_tools-2.2.18.dist-info}/RECORD +8 -8
- {sure_tools-2.2.7.dist-info → sure_tools-2.2.18.dist-info}/WHEEL +0 -0
- {sure_tools-2.2.7.dist-info → sure_tools-2.2.18.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.2.7.dist-info → sure_tools-2.2.18.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.2.7.dist-info → sure_tools-2.2.18.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
|
@@ -59,12 +59,13 @@ class DensityFlow(nn.Module):
|
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
|
+
turn_off_cell_specific: bool = False,
|
|
62
63
|
supervised_mode: bool = False,
|
|
63
64
|
z_dim: int = 10,
|
|
64
65
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
66
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
66
67
|
inverse_dispersion: float = 10.0,
|
|
67
|
-
use_zeroinflate: bool =
|
|
68
|
+
use_zeroinflate: bool = False,
|
|
68
69
|
hidden_layers: list = [500],
|
|
69
70
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
71
|
nn_dropout: float = 0.1,
|
|
@@ -102,6 +103,7 @@ class DensityFlow(nn.Module):
|
|
|
102
103
|
else:
|
|
103
104
|
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
104
105
|
#self.use_bias = not zero_bias
|
|
106
|
+
self.turn_off_cell_specific = turn_off_cell_specific
|
|
105
107
|
|
|
106
108
|
self.codebook_weights = None
|
|
107
109
|
|
|
@@ -203,27 +205,51 @@ class DensityFlow(nn.Module):
|
|
|
203
205
|
self.cell_factor_effect = nn.ModuleList()
|
|
204
206
|
for i in np.arange(self.cell_factor_size):
|
|
205
207
|
if self.use_bias[i]:
|
|
206
|
-
self.
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
208
|
+
if self.turn_off_cell_specific:
|
|
209
|
+
self.cell_factor_effect.append(MLP(
|
|
210
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
211
|
+
activation=activate_fct,
|
|
212
|
+
output_activation=None,
|
|
213
|
+
post_layer_fct=post_layer_fct,
|
|
214
|
+
post_act_fct=post_act_fct,
|
|
215
|
+
allow_broadcast=self.allow_broadcast,
|
|
216
|
+
use_cuda=self.use_cuda,
|
|
217
|
+
)
|
|
218
|
+
)
|
|
219
|
+
else:
|
|
220
|
+
self.cell_factor_effect.append(MLP(
|
|
221
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
222
|
+
activation=activate_fct,
|
|
223
|
+
output_activation=None,
|
|
224
|
+
post_layer_fct=post_layer_fct,
|
|
225
|
+
post_act_fct=post_act_fct,
|
|
226
|
+
allow_broadcast=self.allow_broadcast,
|
|
227
|
+
use_cuda=self.use_cuda,
|
|
228
|
+
)
|
|
214
229
|
)
|
|
215
|
-
)
|
|
216
230
|
else:
|
|
217
|
-
self.
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
231
|
+
if self.turn_off_cell_specific:
|
|
232
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
233
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
234
|
+
activation=activate_fct,
|
|
235
|
+
output_activation=None,
|
|
236
|
+
post_layer_fct=post_layer_fct,
|
|
237
|
+
post_act_fct=post_act_fct,
|
|
238
|
+
allow_broadcast=self.allow_broadcast,
|
|
239
|
+
use_cuda=self.use_cuda,
|
|
240
|
+
)
|
|
241
|
+
)
|
|
242
|
+
else:
|
|
243
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
244
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
245
|
+
activation=activate_fct,
|
|
246
|
+
output_activation=None,
|
|
247
|
+
post_layer_fct=post_layer_fct,
|
|
248
|
+
post_act_fct=post_act_fct,
|
|
249
|
+
allow_broadcast=self.allow_broadcast,
|
|
250
|
+
use_cuda=self.use_cuda,
|
|
251
|
+
)
|
|
225
252
|
)
|
|
226
|
-
)
|
|
227
253
|
|
|
228
254
|
self.decoder_concentrate = MLP(
|
|
229
255
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -676,9 +702,17 @@ class DensityFlow(nn.Module):
|
|
|
676
702
|
zus = None
|
|
677
703
|
for i in np.arange(self.cell_factor_size):
|
|
678
704
|
if i==0:
|
|
679
|
-
|
|
705
|
+
#if self.turn_off_cell_specific:
|
|
706
|
+
# zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
707
|
+
#else:
|
|
708
|
+
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
709
|
+
zus = self._cell_response(zns, i, us)
|
|
680
710
|
else:
|
|
681
|
-
|
|
711
|
+
#if self.turn_off_cell_specific:
|
|
712
|
+
# zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
713
|
+
#else:
|
|
714
|
+
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
715
|
+
zus = zus + self._cell_response(zns, i, us)
|
|
682
716
|
return zus
|
|
683
717
|
|
|
684
718
|
def _get_codebook_identity(self):
|
|
@@ -696,7 +730,7 @@ class DensityFlow(nn.Module):
|
|
|
696
730
|
"""
|
|
697
731
|
Return the mean part of metacell codebook
|
|
698
732
|
"""
|
|
699
|
-
cb = self.
|
|
733
|
+
cb = self._get_codebook()
|
|
700
734
|
cb = tensor_to_numpy(cb)
|
|
701
735
|
return cb
|
|
702
736
|
|
|
@@ -820,12 +854,12 @@ class DensityFlow(nn.Module):
|
|
|
820
854
|
us_i = us[:,pert_idx].reshape(-1,1)
|
|
821
855
|
|
|
822
856
|
# factor effect of xs
|
|
823
|
-
dzs0 = self.get_cell_response(
|
|
857
|
+
dzs0 = self.get_cell_response(zs, factor_idx=pert_idx, perturb=us_i)
|
|
824
858
|
|
|
825
859
|
# perturbation effect
|
|
826
860
|
ps = np.ones_like(us_i)
|
|
827
861
|
if np.sum(np.abs(ps-us_i))>=1:
|
|
828
|
-
dzs = self.get_cell_response(
|
|
862
|
+
dzs = self.get_cell_response(zs, factor_idx=pert_idx, perturb=ps)
|
|
829
863
|
zs = zs + dzs0 + dzs
|
|
830
864
|
else:
|
|
831
865
|
zs = zs + dzs0
|
|
@@ -842,47 +876,48 @@ class DensityFlow(nn.Module):
|
|
|
842
876
|
|
|
843
877
|
return counts, zs
|
|
844
878
|
|
|
845
|
-
def _cell_response(self,
|
|
879
|
+
def _cell_response(self, zs, perturb_idx, perturb):
|
|
846
880
|
#zns,_ = self.encoder_zn(xs)
|
|
847
|
-
zns,_ = self._get_basal_embedding(xs)
|
|
881
|
+
#zns,_ = self._get_basal_embedding(xs)
|
|
882
|
+
zns = zs
|
|
848
883
|
if perturb.ndim==2:
|
|
849
|
-
|
|
884
|
+
if self.turn_off_cell_specific:
|
|
885
|
+
ms = self.cell_factor_effect[perturb_idx](perturb)
|
|
886
|
+
else:
|
|
887
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
|
|
850
888
|
else:
|
|
851
|
-
|
|
889
|
+
if self.turn_off_cell_specific:
|
|
890
|
+
ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
|
|
891
|
+
else:
|
|
892
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
|
|
852
893
|
|
|
853
894
|
return ms
|
|
854
895
|
|
|
855
896
|
def get_cell_response(self,
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
897
|
+
zs,
|
|
898
|
+
perturb_idx,
|
|
899
|
+
perturb_us,
|
|
859
900
|
batch_size: int = 1024):
|
|
860
901
|
"""
|
|
861
902
|
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
862
903
|
|
|
863
904
|
"""
|
|
864
|
-
xs = self.preprocess(xs)
|
|
865
|
-
|
|
866
|
-
ps = convert_to_tensor(
|
|
867
|
-
dataset = CustomDataset2(
|
|
905
|
+
#xs = self.preprocess(xs)
|
|
906
|
+
zs = convert_to_tensor(zs, device=self.get_device())
|
|
907
|
+
ps = convert_to_tensor(perturb_us, device=self.get_device())
|
|
908
|
+
dataset = CustomDataset2(zs,ps)
|
|
868
909
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
869
910
|
|
|
870
911
|
Z = []
|
|
871
912
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
872
|
-
for
|
|
873
|
-
zns = self._cell_response(
|
|
913
|
+
for Z_batch, P_batch, _ in dataloader:
|
|
914
|
+
zns = self._cell_response(Z_batch, perturb_idx, P_batch)
|
|
874
915
|
Z.append(tensor_to_numpy(zns))
|
|
875
916
|
pbar.update(1)
|
|
876
917
|
|
|
877
918
|
Z = np.concatenate(Z)
|
|
878
919
|
return Z
|
|
879
920
|
|
|
880
|
-
def get_metacell_response(self, factor_idx, perturb):
|
|
881
|
-
zs = self._get_codebook()
|
|
882
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
883
|
-
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
884
|
-
return tensor_to_numpy(ms)
|
|
885
|
-
|
|
886
921
|
def _get_expression_response(self, delta_zs):
|
|
887
922
|
return self.decoder_concentrate(delta_zs)
|
|
888
923
|
|
|
@@ -907,7 +942,7 @@ class DensityFlow(nn.Module):
|
|
|
907
942
|
R = np.concatenate(R)
|
|
908
943
|
return R
|
|
909
944
|
|
|
910
|
-
def _count(self,concentrate, library_size=None):
|
|
945
|
+
def _count(self, concentrate, library_size=None):
|
|
911
946
|
if self.loss_func == 'bernoulli':
|
|
912
947
|
#counts = self.sigmoid(concentrate)
|
|
913
948
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
@@ -917,18 +952,8 @@ class DensityFlow(nn.Module):
|
|
|
917
952
|
counts = theta * library_size
|
|
918
953
|
return counts
|
|
919
954
|
|
|
920
|
-
def _count_sample(self,concentrate):
|
|
921
|
-
if self.loss_func == 'bernoulli':
|
|
922
|
-
logits = concentrate
|
|
923
|
-
counts = dist.Bernoulli(logits=logits).to_event(1).sample()
|
|
924
|
-
else:
|
|
925
|
-
counts = self._count(concentrate=concentrate)
|
|
926
|
-
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
927
|
-
return counts
|
|
928
|
-
|
|
929
955
|
def get_counts(self, zs, library_sizes,
|
|
930
|
-
batch_size: int = 1024
|
|
931
|
-
use_sampler: bool = False):
|
|
956
|
+
batch_size: int = 1024):
|
|
932
957
|
|
|
933
958
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
934
959
|
|
|
@@ -945,10 +970,7 @@ class DensityFlow(nn.Module):
|
|
|
945
970
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
946
971
|
for Z_batch, L_batch, _ in dataloader:
|
|
947
972
|
concentrate = self._get_expression_response(Z_batch)
|
|
948
|
-
|
|
949
|
-
counts = self._count_sample(concentrate)
|
|
950
|
-
else:
|
|
951
|
-
counts = self._count(concentrate, L_batch)
|
|
973
|
+
counts = self._count(concentrate, L_batch)
|
|
952
974
|
E.append(tensor_to_numpy(counts))
|
|
953
975
|
pbar.update(1)
|
|
954
976
|
|
|
@@ -1093,9 +1115,6 @@ class DensityFlow(nn.Module):
|
|
|
1093
1115
|
# Update progress bar
|
|
1094
1116
|
pbar.set_postfix({'loss': str_loss})
|
|
1095
1117
|
pbar.update(1)
|
|
1096
|
-
|
|
1097
|
-
if self.loss_func == 'negbinomial':
|
|
1098
|
-
self.inverse_dispersion = pyro.param("inverse_dispersion")
|
|
1099
1118
|
|
|
1100
1119
|
@classmethod
|
|
1101
1120
|
def save_model(cls, model, file_path, compression=False):
|
|
@@ -1359,5 +1378,4 @@ def main():
|
|
|
1359
1378
|
|
|
1360
1379
|
|
|
1361
1380
|
if __name__ == "__main__":
|
|
1362
|
-
|
|
1363
1381
|
main()
|
SURE/perturb/perturb.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
3
4
|
from numba import njit
|
|
4
5
|
from itertools import chain
|
|
5
6
|
from joblib import Parallel, delayed
|
|
@@ -8,6 +9,8 @@ from typing import Literal
|
|
|
8
9
|
class LabelMatrix:
|
|
9
10
|
def __init__(self):
|
|
10
11
|
self.labels_ = None
|
|
12
|
+
self.control_label = None
|
|
13
|
+
self.sep_pattern = None
|
|
11
14
|
|
|
12
15
|
def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
|
|
13
16
|
if speedup=='none':
|
|
@@ -24,8 +27,31 @@ class LabelMatrix:
|
|
|
24
27
|
mat = np.delete(mat, idx, axis=1)
|
|
25
28
|
self.labels_ = np.delete(self.labels_, idx)
|
|
26
29
|
|
|
30
|
+
self.control_label = control_label
|
|
31
|
+
self.sep_pattern=sep_pattern
|
|
32
|
+
|
|
27
33
|
return mat
|
|
28
|
-
|
|
34
|
+
|
|
35
|
+
def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
|
|
36
|
+
sep_pattern = self.sep_pattern
|
|
37
|
+
if speedup=='none':
|
|
38
|
+
mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
39
|
+
elif speedup=='vectorize':
|
|
40
|
+
mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
41
|
+
elif speedup=='parallel':
|
|
42
|
+
mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
43
|
+
|
|
44
|
+
mat_df = pd.DataFrame(mat, columns=labels_)
|
|
45
|
+
|
|
46
|
+
labels_valid = [x for x in labels_ if x in self.labels_]
|
|
47
|
+
mat_df = mat_df[labels_valid]
|
|
48
|
+
|
|
49
|
+
mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
|
|
50
|
+
mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
|
|
51
|
+
mat_valid_df[labels_valid] = mat_df
|
|
52
|
+
|
|
53
|
+
return mat_valid_df.values
|
|
54
|
+
|
|
29
55
|
def inverse_transform(self, matrix):
|
|
30
56
|
return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
|
|
31
57
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/DensityFlow.py,sha256=
|
|
1
|
+
SURE/DensityFlow.py,sha256=FQ5LT-5xo_O3Qa5I0hxYBMq8f7HvapxfwTQj-oc3iyI,56086
|
|
2
2
|
SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
|
|
3
3
|
SURE/__init__.py,sha256=NVp22RCHrhSwHNMomABC-eftoCYvt7vV1XOzim-UZHE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -12,14 +12,14 @@ SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
|
|
|
12
12
|
SURE/flow/flow_stats.py,sha256=6SzNMT59WRFRP1nC6bvpBPF7BugWnkIS_DSlr4S-Ez0,11338
|
|
13
13
|
SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
|
|
14
14
|
SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
|
|
15
|
-
SURE/perturb/perturb.py,sha256=
|
|
15
|
+
SURE/perturb/perturb.py,sha256=ey7cxsM1tO1MW4UaE_MLpLHK87CjvXzn2CBPtvv1VZ0,6116
|
|
16
16
|
SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
17
17
|
SURE/utils/custom_mlp.py,sha256=HuNb7f8-6RFjsvfEu1XOuNpLrHZkGYHgf8TpJfPSNO0,9382
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.2.
|
|
21
|
-
sure_tools-2.2.
|
|
22
|
-
sure_tools-2.2.
|
|
23
|
-
sure_tools-2.2.
|
|
24
|
-
sure_tools-2.2.
|
|
25
|
-
sure_tools-2.2.
|
|
20
|
+
sure_tools-2.2.18.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.2.18.dist-info/METADATA,sha256=QIr1OsgtniZmJyGGBnA-Pfmm385gGJl6W6vMF1KxgRY,2678
|
|
22
|
+
sure_tools-2.2.18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.2.18.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.2.18.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.2.18.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|