SURE-tools 2.2.2__tar.gz → 2.2.25__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools has been flagged as potentially problematic; see the registry's advisory for this release for more details.
- {sure_tools-2.2.2 → sure_tools-2.2.25}/PKG-INFO +1 -1
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/DensityFlow.py +88 -66
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/perturb/perturb.py +27 -1
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/utils/custom_mlp.py +8 -2
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.2.2 → sure_tools-2.2.25}/setup.py +1 -1
- {sure_tools-2.2.2 → sure_tools-2.2.25}/LICENSE +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/README.md +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/SURE.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/__init__.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/atac/utils.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/utils/queue.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE/utils/utils.py +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.2.2 → sure_tools-2.2.25}/setup.cfg +0 -0
|
@@ -59,12 +59,13 @@ class DensityFlow(nn.Module):
|
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
|
+
turn_off_cell_specific: bool = False,
|
|
62
63
|
supervised_mode: bool = False,
|
|
63
64
|
z_dim: int = 10,
|
|
64
65
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
66
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
66
67
|
inverse_dispersion: float = 10.0,
|
|
67
|
-
use_zeroinflate: bool =
|
|
68
|
+
use_zeroinflate: bool = False,
|
|
68
69
|
hidden_layers: list = [500],
|
|
69
70
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
71
|
nn_dropout: float = 0.1,
|
|
@@ -102,6 +103,7 @@ class DensityFlow(nn.Module):
|
|
|
102
103
|
else:
|
|
103
104
|
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
104
105
|
#self.use_bias = not zero_bias
|
|
106
|
+
self.turn_off_cell_specific = turn_off_cell_specific
|
|
105
107
|
|
|
106
108
|
self.codebook_weights = None
|
|
107
109
|
|
|
@@ -203,27 +205,51 @@ class DensityFlow(nn.Module):
|
|
|
203
205
|
self.cell_factor_effect = nn.ModuleList()
|
|
204
206
|
for i in np.arange(self.cell_factor_size):
|
|
205
207
|
if self.use_bias[i]:
|
|
206
|
-
self.
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
208
|
+
if self.turn_off_cell_specific:
|
|
209
|
+
self.cell_factor_effect.append(MLP(
|
|
210
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
211
|
+
activation=activate_fct,
|
|
212
|
+
output_activation=None,
|
|
213
|
+
post_layer_fct=post_layer_fct,
|
|
214
|
+
post_act_fct=post_act_fct,
|
|
215
|
+
allow_broadcast=self.allow_broadcast,
|
|
216
|
+
use_cuda=self.use_cuda,
|
|
217
|
+
)
|
|
218
|
+
)
|
|
219
|
+
else:
|
|
220
|
+
self.cell_factor_effect.append(MLP(
|
|
221
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
222
|
+
activation=activate_fct,
|
|
223
|
+
output_activation=None,
|
|
224
|
+
post_layer_fct=post_layer_fct,
|
|
225
|
+
post_act_fct=post_act_fct,
|
|
226
|
+
allow_broadcast=self.allow_broadcast,
|
|
227
|
+
use_cuda=self.use_cuda,
|
|
228
|
+
)
|
|
214
229
|
)
|
|
215
|
-
)
|
|
216
230
|
else:
|
|
217
|
-
self.
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
231
|
+
if self.turn_off_cell_specific:
|
|
232
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
233
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
234
|
+
activation=activate_fct,
|
|
235
|
+
output_activation=None,
|
|
236
|
+
post_layer_fct=post_layer_fct,
|
|
237
|
+
post_act_fct=post_act_fct,
|
|
238
|
+
allow_broadcast=self.allow_broadcast,
|
|
239
|
+
use_cuda=self.use_cuda,
|
|
240
|
+
)
|
|
241
|
+
)
|
|
242
|
+
else:
|
|
243
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
244
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
245
|
+
activation=activate_fct,
|
|
246
|
+
output_activation=None,
|
|
247
|
+
post_layer_fct=post_layer_fct,
|
|
248
|
+
post_act_fct=post_act_fct,
|
|
249
|
+
allow_broadcast=self.allow_broadcast,
|
|
250
|
+
use_cuda=self.use_cuda,
|
|
251
|
+
)
|
|
225
252
|
)
|
|
226
|
-
)
|
|
227
253
|
|
|
228
254
|
self.decoder_concentrate = MLP(
|
|
229
255
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -676,9 +702,17 @@ class DensityFlow(nn.Module):
|
|
|
676
702
|
zus = None
|
|
677
703
|
for i in np.arange(self.cell_factor_size):
|
|
678
704
|
if i==0:
|
|
679
|
-
|
|
705
|
+
#if self.turn_off_cell_specific:
|
|
706
|
+
# zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
707
|
+
#else:
|
|
708
|
+
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
709
|
+
zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
680
710
|
else:
|
|
681
|
-
|
|
711
|
+
#if self.turn_off_cell_specific:
|
|
712
|
+
# zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
713
|
+
#else:
|
|
714
|
+
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
715
|
+
zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
682
716
|
return zus
|
|
683
717
|
|
|
684
718
|
def _get_codebook_identity(self):
|
|
@@ -696,7 +730,7 @@ class DensityFlow(nn.Module):
|
|
|
696
730
|
"""
|
|
697
731
|
Return the mean part of metacell codebook
|
|
698
732
|
"""
|
|
699
|
-
cb = self.
|
|
733
|
+
cb = self._get_codebook()
|
|
700
734
|
cb = tensor_to_numpy(cb)
|
|
701
735
|
return cb
|
|
702
736
|
|
|
@@ -820,13 +854,15 @@ class DensityFlow(nn.Module):
|
|
|
820
854
|
us_i = us[:,pert_idx].reshape(-1,1)
|
|
821
855
|
|
|
822
856
|
# factor effect of xs
|
|
823
|
-
dzs0 = self.get_cell_response(
|
|
857
|
+
dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
|
|
824
858
|
|
|
825
859
|
# perturbation effect
|
|
826
860
|
ps = np.ones_like(us_i)
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
861
|
+
if np.sum(np.abs(ps-us_i))>=1:
|
|
862
|
+
dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
|
|
863
|
+
zs = zs + dzs0 + dzs
|
|
864
|
+
else:
|
|
865
|
+
zs = zs + dzs0
|
|
830
866
|
|
|
831
867
|
if library_sizes is None:
|
|
832
868
|
library_sizes = np.sum(xs, axis=1, keepdims=True)
|
|
@@ -840,47 +876,48 @@ class DensityFlow(nn.Module):
|
|
|
840
876
|
|
|
841
877
|
return counts, zs
|
|
842
878
|
|
|
843
|
-
def _cell_response(self,
|
|
879
|
+
def _cell_response(self, zs, perturb_idx, perturb):
|
|
844
880
|
#zns,_ = self.encoder_zn(xs)
|
|
845
|
-
zns,_ = self._get_basal_embedding(xs)
|
|
881
|
+
#zns,_ = self._get_basal_embedding(xs)
|
|
882
|
+
zns = zs
|
|
846
883
|
if perturb.ndim==2:
|
|
847
|
-
|
|
884
|
+
if self.turn_off_cell_specific:
|
|
885
|
+
ms = self.cell_factor_effect[perturb_idx](perturb)
|
|
886
|
+
else:
|
|
887
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
|
|
848
888
|
else:
|
|
849
|
-
|
|
889
|
+
if self.turn_off_cell_specific:
|
|
890
|
+
ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
|
|
891
|
+
else:
|
|
892
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
|
|
850
893
|
|
|
851
894
|
return ms
|
|
852
895
|
|
|
853
896
|
def get_cell_response(self,
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
897
|
+
zs,
|
|
898
|
+
perturb_idx,
|
|
899
|
+
perturb_us,
|
|
857
900
|
batch_size: int = 1024):
|
|
858
901
|
"""
|
|
859
902
|
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
860
903
|
|
|
861
904
|
"""
|
|
862
|
-
xs = self.preprocess(xs)
|
|
863
|
-
|
|
864
|
-
ps = convert_to_tensor(
|
|
865
|
-
dataset = CustomDataset2(
|
|
905
|
+
#xs = self.preprocess(xs)
|
|
906
|
+
zs = convert_to_tensor(zs, device=self.get_device())
|
|
907
|
+
ps = convert_to_tensor(perturb_us, device=self.get_device())
|
|
908
|
+
dataset = CustomDataset2(zs,ps)
|
|
866
909
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
867
910
|
|
|
868
911
|
Z = []
|
|
869
912
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
870
|
-
for
|
|
871
|
-
zns = self._cell_response(
|
|
913
|
+
for Z_batch, P_batch, _ in dataloader:
|
|
914
|
+
zns = self._cell_response(Z_batch, perturb_idx, P_batch)
|
|
872
915
|
Z.append(tensor_to_numpy(zns))
|
|
873
916
|
pbar.update(1)
|
|
874
917
|
|
|
875
918
|
Z = np.concatenate(Z)
|
|
876
919
|
return Z
|
|
877
920
|
|
|
878
|
-
def get_metacell_response(self, factor_idx, perturb):
|
|
879
|
-
zs = self._get_codebook()
|
|
880
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
881
|
-
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
882
|
-
return tensor_to_numpy(ms)
|
|
883
|
-
|
|
884
921
|
def _get_expression_response(self, delta_zs):
|
|
885
922
|
return self.decoder_concentrate(delta_zs)
|
|
886
923
|
|
|
@@ -905,7 +942,7 @@ class DensityFlow(nn.Module):
|
|
|
905
942
|
R = np.concatenate(R)
|
|
906
943
|
return R
|
|
907
944
|
|
|
908
|
-
def _count(self,concentrate, library_size=None):
|
|
945
|
+
def _count(self, concentrate, library_size=None):
|
|
909
946
|
if self.loss_func == 'bernoulli':
|
|
910
947
|
#counts = self.sigmoid(concentrate)
|
|
911
948
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
@@ -913,28 +950,17 @@ class DensityFlow(nn.Module):
|
|
|
913
950
|
rate = concentrate.exp()
|
|
914
951
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
915
952
|
counts = theta * library_size
|
|
916
|
-
#counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
917
|
-
return counts
|
|
918
|
-
|
|
919
|
-
def _count_sample(self,concentrate):
|
|
920
|
-
if self.loss_func == 'bernoulli':
|
|
921
|
-
logits = concentrate
|
|
922
|
-
counts = dist.Bernoulli(logits=logits).to_event(1).sample()
|
|
923
|
-
else:
|
|
924
|
-
counts = self._count(concentrate=concentrate)
|
|
925
|
-
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
926
953
|
return counts
|
|
927
954
|
|
|
928
955
|
def get_counts(self, zs, library_sizes,
|
|
929
|
-
batch_size: int = 1024
|
|
930
|
-
use_sampler: bool = False):
|
|
956
|
+
batch_size: int = 1024):
|
|
931
957
|
|
|
932
958
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
933
959
|
|
|
934
960
|
if type(library_sizes) == list:
|
|
935
|
-
library_sizes = np.array(library_sizes).
|
|
961
|
+
library_sizes = np.array(library_sizes).reshape(-1,1)
|
|
936
962
|
elif len(library_sizes.shape)==1:
|
|
937
|
-
library_sizes = library_sizes.
|
|
963
|
+
library_sizes = library_sizes.reshape(-1,1)
|
|
938
964
|
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
939
965
|
|
|
940
966
|
dataset = CustomDataset2(zs,ls)
|
|
@@ -944,10 +970,7 @@ class DensityFlow(nn.Module):
|
|
|
944
970
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
945
971
|
for Z_batch, L_batch, _ in dataloader:
|
|
946
972
|
concentrate = self._get_expression_response(Z_batch)
|
|
947
|
-
|
|
948
|
-
counts = self._count_sample(concentrate)
|
|
949
|
-
else:
|
|
950
|
-
counts = self._count(concentrate, L_batch)
|
|
973
|
+
counts = self._count(concentrate, L_batch)
|
|
951
974
|
E.append(tensor_to_numpy(counts))
|
|
952
975
|
pbar.update(1)
|
|
953
976
|
|
|
@@ -1355,5 +1378,4 @@ def main():
|
|
|
1355
1378
|
|
|
1356
1379
|
|
|
1357
1380
|
if __name__ == "__main__":
|
|
1358
|
-
|
|
1359
1381
|
main()
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
3
4
|
from numba import njit
|
|
4
5
|
from itertools import chain
|
|
5
6
|
from joblib import Parallel, delayed
|
|
@@ -8,6 +9,8 @@ from typing import Literal
|
|
|
8
9
|
class LabelMatrix:
|
|
9
10
|
def __init__(self):
|
|
10
11
|
self.labels_ = None
|
|
12
|
+
self.control_label = None
|
|
13
|
+
self.sep_pattern = None
|
|
11
14
|
|
|
12
15
|
def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
|
|
13
16
|
if speedup=='none':
|
|
@@ -24,8 +27,31 @@ class LabelMatrix:
|
|
|
24
27
|
mat = np.delete(mat, idx, axis=1)
|
|
25
28
|
self.labels_ = np.delete(self.labels_, idx)
|
|
26
29
|
|
|
30
|
+
self.control_label = control_label
|
|
31
|
+
self.sep_pattern=sep_pattern
|
|
32
|
+
|
|
27
33
|
return mat
|
|
28
|
-
|
|
34
|
+
|
|
35
|
+
def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
|
|
36
|
+
sep_pattern = self.sep_pattern
|
|
37
|
+
if speedup=='none':
|
|
38
|
+
mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
39
|
+
elif speedup=='vectorize':
|
|
40
|
+
mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
41
|
+
elif speedup=='parallel':
|
|
42
|
+
mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
43
|
+
|
|
44
|
+
mat_df = pd.DataFrame(mat, columns=labels_)
|
|
45
|
+
|
|
46
|
+
labels_valid = [x for x in labels_ if x in self.labels_]
|
|
47
|
+
mat_df = mat_df[labels_valid]
|
|
48
|
+
|
|
49
|
+
mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
|
|
50
|
+
mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
|
|
51
|
+
mat_valid_df[labels_valid] = mat_df
|
|
52
|
+
|
|
53
|
+
return mat_valid_df.values
|
|
54
|
+
|
|
29
55
|
def inverse_transform(self, matrix):
|
|
30
56
|
return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
|
|
31
57
|
|
|
@@ -240,9 +240,15 @@ class ZeroBiasMLP(nn.Module):
|
|
|
240
240
|
y = self.mlp(x)
|
|
241
241
|
mask = torch.zeros_like(y)
|
|
242
242
|
if len(y.shape)==2:
|
|
243
|
-
|
|
243
|
+
if len(x)>2:
|
|
244
|
+
mask[x[1][:,0]>0,:] = 1
|
|
245
|
+
else:
|
|
246
|
+
mask[x[:,0]>0,:] = 1
|
|
244
247
|
elif len(y.shape)==3:
|
|
245
|
-
|
|
248
|
+
if len(x)>1:
|
|
249
|
+
mask[:,x[1][:,0]>0,:] = 1
|
|
250
|
+
else:
|
|
251
|
+
mask[:,x[:,0]>0,:] = 1
|
|
246
252
|
return y*mask
|
|
247
253
|
|
|
248
254
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|