SURE-tools 2.0.5__tar.gz → 2.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.0.5 → sure_tools-2.0.7}/PKG-INFO +1 -1
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/SURE.py +18 -5
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/flow/plot_quiver.py +1 -1
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.0.5 → sure_tools-2.0.7}/setup.py +1 -1
- {sure_tools-2.0.5 → sure_tools-2.0.7}/LICENSE +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/README.md +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/__init__.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/atac/utils.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/utils/queue.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE/utils/utils.py +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.0.5 → sure_tools-2.0.7}/setup.cfg +0 -0
|
@@ -11,7 +11,7 @@ from torch.distributions import constraints
|
|
|
11
11
|
from torch.distributions.transforms import SoftmaxTransform
|
|
12
12
|
|
|
13
13
|
from .utils.custom_mlp import MLP, Exp
|
|
14
|
-
from .utils.utils import CustomDataset,
|
|
14
|
+
from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
import os
|
|
@@ -814,11 +814,13 @@ class SURE(nn.Module):
|
|
|
814
814
|
return A
|
|
815
815
|
|
|
816
816
|
def _cell_move(self, xs, factor_idx, perturb):
|
|
817
|
-
zns = self.encoder_zn(xs)
|
|
817
|
+
zns,_ = self.encoder_zn(xs)
|
|
818
818
|
if type(factor_idx) == str:
|
|
819
819
|
factor_idx = np.where(self.cell_factor_names==factor_idx)
|
|
820
820
|
|
|
821
821
|
if perturb.ndim==2:
|
|
822
|
+
print(factor_idx)
|
|
823
|
+
print(type(perturb))
|
|
822
824
|
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
823
825
|
else:
|
|
824
826
|
ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
|
|
@@ -836,19 +838,30 @@ class SURE(nn.Module):
|
|
|
836
838
|
"""
|
|
837
839
|
xs = self.preprocess(xs)
|
|
838
840
|
xs = convert_to_tensor(xs, device=self.get_device())
|
|
839
|
-
|
|
841
|
+
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
842
|
+
dataset = CustomDataset2(xs,ps)
|
|
840
843
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
841
844
|
|
|
842
845
|
Z = []
|
|
843
846
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
844
|
-
for X_batch, _ in dataloader:
|
|
845
|
-
zns = self._cell_move(X_batch, factor_idx,
|
|
847
|
+
for X_batch, P_batch, _ in dataloader:
|
|
848
|
+
zns = self._cell_move(X_batch, factor_idx, P_batch)
|
|
846
849
|
Z.append(tensor_to_numpy(zns))
|
|
847
850
|
pbar.update(1)
|
|
848
851
|
|
|
849
852
|
Z = np.concatenate(Z)
|
|
850
853
|
return Z
|
|
851
854
|
|
|
855
|
+
def get_metacell_move(self, factor_idx, perturb):
|
|
856
|
+
zs = self._get_codebook()
|
|
857
|
+
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
858
|
+
|
|
859
|
+
if type(factor_idx) == str:
|
|
860
|
+
factor_idx = np.where(self.cell_factor_names==factor_idx)
|
|
861
|
+
|
|
862
|
+
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
863
|
+
return tensor_to_numpy(ms)
|
|
864
|
+
|
|
852
865
|
def preprocess(self, xs, threshold=0):
|
|
853
866
|
if self.loss_func == 'bernoulli':
|
|
854
867
|
ad = sc.AnnData(xs)
|
|
@@ -32,7 +32,7 @@ def plot_quiver(z_points, delta_z, method='umap', figsize=(6,4), dpi=200):
|
|
|
32
32
|
sc.pp.neighbors(ad)
|
|
33
33
|
sc.tl.umap(ad)
|
|
34
34
|
z_2d = ad[:z_points.shape[0]].obsm['X_umap']
|
|
35
|
-
delta_z_2d = ad[z_points.shape[0]:] - z_2d
|
|
35
|
+
delta_z_2d = ad[z_points.shape[0]:].obsm['X_umap'] - z_2d
|
|
36
36
|
dim_names = ['UMAP1', 'UMAP2']
|
|
37
37
|
|
|
38
38
|
# 绘制quiver图
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|