SURE-tools 2.2.7.tar.gz → 2.2.15.tar.gz
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- {sure_tools-2.2.7 → sure_tools-2.2.15}/PKG-INFO +1 -1
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/DensityFlow.py +18 -39
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/perturb/perturb.py +27 -1
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.2.7 → sure_tools-2.2.15}/setup.py +1 -1
- {sure_tools-2.2.7 → sure_tools-2.2.15}/LICENSE +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/README.md +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/SURE.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/__init__.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/atac/utils.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/utils/queue.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/utils/utils.py +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.2.7 → sure_tools-2.2.15}/setup.cfg +0 -0
SURE/DensityFlow.py

@@ -696,7 +696,7 @@ class DensityFlow(nn.Module):
         """
         Return the mean part of metacell codebook
         """
-        cb = self.
+        cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb
 
@@ -842,47 +842,42 @@ class DensityFlow(nn.Module):
 
         return counts, zs
 
-    def _cell_response(self,
+    def _cell_response(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
-        zns,_ = self._get_basal_embedding(xs)
+        #zns,_ = self._get_basal_embedding(xs)
+        zns = zs
         if perturb.ndim==2:
-            ms = self.cell_factor_effect[
+            ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
         else:
-            ms = self.cell_factor_effect[
+            ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
 
         return ms
 
     def get_cell_response(self,
-
-
-
+                          zs,
+                          perturb_idx,
+                          perturb_us,
                           batch_size: int = 1024):
         """
         Return cells' changes in the latent space induced by specific perturbation of a factor
 
         """
-        xs = self.preprocess(xs)
-
-        ps = convert_to_tensor(
-        dataset = CustomDataset2(
+        #xs = self.preprocess(xs)
+        zs = convert_to_tensor(zs, device=self.get_device())
+        ps = convert_to_tensor(perturb_us, device=self.get_device())
+        dataset = CustomDataset2(zs,ps)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-            zns = self._cell_response(
+            for Z_batch, P_batch, _ in dataloader:
+                zns = self._cell_response(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)
 
         Z = np.concatenate(Z)
         return Z
 
-    def get_metacell_response(self, factor_idx, perturb):
-        zs = self._get_codebook()
-        ps = convert_to_tensor(perturb, device=self.get_device())
-        ms = self.cell_factor_effect[factor_idx]([zs,ps])
-        return tensor_to_numpy(ms)
-
     def _get_expression_response(self, delta_zs):
         return self.decoder_concentrate(delta_zs)
 
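The hunk above reworks `get_cell_response` to take latent embeddings (`zs`) directly instead of raw expression (`xs`), together with an explicit factor index (`perturb_idx`) and per-cell perturbation values (`perturb_us`). The sketch below is a minimal, self-contained illustration of the 1-D vs 2-D dispatch `_cell_response` now performs; the stand-in `effect_fn` is hypothetical and merely plays the role of `self.cell_factor_effect[perturb_idx]`.

```python
import numpy as np

def cell_response_sketch(zs, perturb, effect_fn):
    """Mirror _cell_response's branching: a 1-D dose vector is reshaped
    into a column before being paired with the latent embeddings zs.
    effect_fn stands in for self.cell_factor_effect[perturb_idx]."""
    if perturb.ndim == 2:
        return effect_fn([zs, perturb])
    return effect_fn([zs, perturb.reshape(-1, 1)])

zs = np.random.randn(5, 10)      # 5 cells, 10 latent dimensions (toy data)
dose = np.ones(5)                # 1-D perturbation vector
ms = cell_response_sketch(zs, dose, lambda pair: np.hstack(pair))
print(ms.shape)                  # (5, 11): zs plus the reshaped dose column
```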
@@ -907,7 +902,7 @@ class DensityFlow(nn.Module):
         R = np.concatenate(R)
         return R
 
-    def _count(self,concentrate, library_size=None):
+    def _count(self, concentrate, library_size=None):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
@@ -917,18 +912,8 @@ class DensityFlow(nn.Module):
             counts = theta * library_size
         return counts
 
-    def _count_sample(self,concentrate):
-        if self.loss_func == 'bernoulli':
-            logits = concentrate
-            counts = dist.Bernoulli(logits=logits).to_event(1).sample()
-        else:
-            counts = self._count(concentrate=concentrate)
-            counts = dist.Poisson(rate=counts).to_event(1).sample()
-        return counts
-
     def get_counts(self, zs, library_sizes,
-                   batch_size: int = 1024
-                   use_sampler: bool = False):
+                   batch_size: int = 1024):
 
         zs = convert_to_tensor(zs, device=self.get_device())
 
@@ -945,10 +930,7 @@ class DensityFlow(nn.Module):
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
                 concentrate = self._get_expression_response(Z_batch)
-
-                counts = self._count_sample(concentrate)
-                else:
-                counts = self._count(concentrate, L_batch)
+                counts = self._count(concentrate, L_batch)
             E.append(tensor_to_numpy(counts))
             pbar.update(1)
 
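With `_count_sample` and the `use_sampler` flag removed in the two hunks above, `get_counts` now always returns the deterministic distribution mean computed by `_count`. For the Bernoulli branch, the retained line `dist.Bernoulli(logits=concentrate).to_event(1).mean` is simply the element-wise sigmoid of the logits; a small sanity-check sketch (the shapes are purely illustrative, and `torch`/`pyro` are assumed installed):

```python
import torch
import pyro.distributions as dist

# Toy logits for 4 cells x 8 features; in the model these come from
# decoder_concentrate via _get_expression_response.
concentrate = torch.randn(4, 8)

# The deterministic expected counts kept by _count for loss_func == 'bernoulli'
mean_counts = dist.Bernoulli(logits=concentrate).to_event(1).mean

# The mean of Bernoulli(logits=x) is sigmoid(x), element-wise
assert torch.allclose(mean_counts, torch.sigmoid(concentrate))
```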
@@ -1093,9 +1075,6 @@ class DensityFlow(nn.Module):
                     # Update progress bar
                     pbar.set_postfix({'loss': str_loss})
                     pbar.update(1)
-
-        if self.loss_func == 'negbinomial':
-            self.inverse_dispersion = pyro.param("inverse_dispersion")
 
     @classmethod
     def save_model(cls, model, file_path, compression=False):
SURE/perturb/perturb.py

@@ -1,5 +1,6 @@
 import re
 import numpy as np
+import pandas as pd
 from numba import njit
 from itertools import chain
 from joblib import Parallel, delayed
@@ -8,6 +9,8 @@ from typing import Literal
 class LabelMatrix:
     def __init__(self):
         self.labels_ = None
+        self.control_label = None
+        self.sep_pattern = None
 
     def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
         if speedup=='none':
@@ -24,8 +27,31 @@ class LabelMatrix:
             mat = np.delete(mat, idx, axis=1)
             self.labels_ = np.delete(self.labels_, idx)
 
+        self.control_label = control_label
+        self.sep_pattern=sep_pattern
+
         return mat
-
+
+    def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
+        sep_pattern = self.sep_pattern
+        if speedup=='none':
+            mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+        elif speedup=='vectorize':
+            mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+        elif speedup=='parallel':
+            mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+
+        mat_df = pd.DataFrame(mat, columns=labels_)
+
+        labels_valid = [x for x in labels_ if x in self.labels_]
+        mat_df = mat_df[labels_valid]
+
+        mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
+        mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
+        mat_valid_df[labels_valid] = mat_df
+
+        return mat_valid_df.values
+
     def inverse_transform(self, matrix):
         return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
 
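The new `transform` method lets a fitted `LabelMatrix` encode fresh label strings against the vocabulary learned by `fit_transform`: labels never seen at fit time are dropped, and fitted labels absent from the new data yield all-zero columns. A self-contained sketch of the pandas column-alignment trick it relies on (the label names and indicator matrix here are made up for illustration):

```python
import numpy as np
import pandas as pd

fitted_labels = np.array(['geneA', 'geneB', 'geneC'])  # learned at fit time
new_labels = np.array(['geneB', 'geneD'])              # geneD was never fitted
mat = np.array([[1, 0],                                # indicator matrix parsed
                [0, 1]])                               # from the new labels

mat_df = pd.DataFrame(mat, columns=new_labels)
labels_valid = [x for x in new_labels if x in fitted_labels]  # keeps 'geneB'

# Zero-filled frame over the fitted vocabulary, then copy the shared columns
aligned = pd.DataFrame(np.zeros((mat.shape[0], len(fitted_labels))),
                       columns=fitted_labels)
aligned[labels_valid] = mat_df[labels_valid]
print(aligned.values)
# [[0. 1. 0.]
#  [0. 0. 0.]]
```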