SURE-tools 2.0.9.tar.gz → 2.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of SURE-tools might be problematic.
- {sure_tools-2.0.9 → sure_tools-2.1.0}/PKG-INFO +1 -1
- sure_tools-2.0.9/SURE/SURE.py → sure_tools-2.1.0/SURE/PerturbFlow.py +41 -17
- sure_tools-2.1.0/SURE/SURE.py +1253 -0
- sure_tools-2.1.0/SURE/__init__.py +11 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE_tools.egg-info/SOURCES.txt +1 -0
- sure_tools-2.1.0/SURE_tools.egg-info/entry_points.txt +3 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/setup.py +3 -2
- sure_tools-2.0.9/SURE/__init__.py +0 -10
- sure_tools-2.0.9/SURE_tools.egg-info/entry_points.txt +0 -2
- {sure_tools-2.0.9 → sure_tools-2.1.0}/LICENSE +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/README.md +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/atac/utils.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/utils/queue.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/utils/utils.py +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.0.9 → sure_tools-2.1.0}/setup.cfg +0 -0
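
Judging from the file list, 2.1.0 renames the main module from SURE/SURE.py to SURE/PerturbFlow.py (with edits), adds a new 1253-line SURE/SURE.py, rewrites SURE/__init__.py, and registers a new console entry point. A minimal import sketch, assuming the rewritten __init__.py re-exports both classes (its contents are not shown in this diff):

    # Hypothetical import sketch; the exact exports of SURE/__init__.py in
    # 2.1.0 are an assumption based on the file list above.
    from SURE import PerturbFlow   # new class, defined in SURE/PerturbFlow.py
    from SURE import SURE          # original class, kept in SURE/SURE.py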
@@ -54,7 +54,7 @@ def set_random_seed(seed):
   54    54 |     # Set seed for Pyro
   55    55 |     pyro.set_rng_seed(seed)
   56    56 |
   57       | -class SURE(nn.Module):
         57 | +class PerturbFlow(nn.Module):
   58    58 |     """SUccinct REpresentation of single-omics cells
   59    59 |
   60    60 |     Parameters
@@ -330,7 +330,7 @@ class SURE(nn.Module):
  330   330 |         return xs
  331   331 |
  332   332 |     def model1(self, xs):
  333       | -        pyro.module('
        333 | +        pyro.module('PerturbFlow', self)
  334   334 |
  335   335 |         eps = torch.finfo(xs.dtype).eps
  336   336 |         batch_size = xs.size(0)
@@ -407,7 +407,7 @@ class SURE(nn.Module):
  407   407 |         ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
  408   408 |
  409   409 |     def model2(self, xs, us=None):
  410       | -        pyro.module('
        410 | +        pyro.module('PerturbFlow', self)
  411   411 |
  412   412 |         eps = torch.finfo(xs.dtype).eps
  413   413 |         batch_size = xs.size(0)
@@ -495,7 +495,7 @@ class SURE(nn.Module):
  495   495 |         ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
  496   496 |
  497   497 |     def model3(self, xs, ys, embeds=None):
  498       | -        pyro.module('
        498 | +        pyro.module('PerturbFlow', self)
  499   499 |
  500   500 |         eps = torch.finfo(xs.dtype).eps
  501   501 |         batch_size = xs.size(0)
@@ -587,7 +587,7 @@ class SURE(nn.Module):
  587   587 |         zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
  588   588 |
  589   589 |     def model4(self, xs, us, ys, embeds=None):
  590       | -        pyro.module('
        590 | +        pyro.module('PerturbFlow', self)
  591   591 |
  592   592 |         eps = torch.finfo(xs.dtype).eps
  593   593 |         batch_size = xs.size(0)
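
The four hunks above change only the name passed to pyro.module inside model1 through model4. In Pyro, pyro.module(name, nn_module) registers the module's parameters with the global parameter store under that name, so the rename changes the registration key but not the model structure. A minimal standalone sketch of that mechanism (illustrative, not code from the package):

    import pyro
    import torch.nn as nn

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(2, 2)

        def model(self):
            # Registers this nn.Module's parameters in Pyro's param store
            # under the name 'Tiny'; renaming only changes that key.
            pyro.module('Tiny', self)

    Tiny().model()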
@@ -813,7 +813,7 @@ class SURE(nn.Module):
  813   813 |         A = np.concatenate(A)
  814   814 |         return A
  815   815 |
  816       | -    def
        816 | +    def _cell_state_response(self, xs, factor_idx, perturb):
  817   817 |         zns,_ = self.encoder_zn(xs)
  818   818 |         if type(factor_idx) == str:
  819   819 |             factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
@@ -825,7 +825,7 @@ class SURE(nn.Module):
  825   825 |
  826   826 |         return ms
  827   827 |
  828       | -    def
        828 | +    def get_cell_state_response(self,
  829   829 |                                 xs,
  830   830 |                                 factor_idx,
  831   831 |                                 perturb,
@@ -843,14 +843,14 @@ class SURE(nn.Module):
  843   843 |         Z = []
  844   844 |         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
  845   845 |             for X_batch, P_batch, _ in dataloader:
  846       | -                zns = self.
        846 | +                zns = self._cell_state_response(X_batch, factor_idx, P_batch)
  847   847 |                 Z.append(tensor_to_numpy(zns))
  848   848 |                 pbar.update(1)
  849   849 |
  850   850 |         Z = np.concatenate(Z)
  851   851 |         return Z
  852   852 |
  853       | -    def
        853 | +    def get_metacell_response(self, factor_idx, perturb):
  854   854 |         zs = self._get_codebook()
  855   855 |         ps = convert_to_tensor(perturb, device=self.get_device())
  856   856 |
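
For orientation, a hedged usage sketch of the response helpers above on an already trained model. The variable `model`, the data shapes, and the factor name 'treatment' are illustrative assumptions; only the method names and the fact that factor_idx may be given as a string (looked up in cell_factor_names) come from the code shown:

    import numpy as np

    # `model` is assumed to be a trained PerturbFlow instance.
    xs = np.random.poisson(1.0, size=(100, 2000)).astype('float32')  # cells x genes (toy)
    perturb = np.ones((100, 1), dtype='float32')                     # per-cell perturbation values (toy)

    # Latent-space shifts of individual cells under the perturbation,
    # computed batch-wise as in the loop shown above.
    delta_zs = model.get_cell_state_response(xs, factor_idx='treatment', perturb=perturb)

get_metacell_response is the analogous call evaluated on the learned codebook rather than on xs.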
@@ -860,6 +860,30 @@ class SURE(nn.Module):
  860   860 |         ms = self.cell_factor_effect[factor_idx]([zs,ps])
  861   861 |         return tensor_to_numpy(ms)
  862   862 |
        863 | +    def _get_expression_response(self, delta_zs):
        864 | +        return self.decoder_concentrate(delta_zs)
        865 | +
        866 | +    def get_expression_response(self,
        867 | +                                delta_zs,
        868 | +                                batch_size: int = 1024):
        869 | +        """
        870 | +        Return cells' changes in the latent space induced by specific perturbation of a factor
        871 | +
        872 | +        """
        873 | +        delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
        874 | +        dataset = CustomDataset(delta_zs)
        875 | +        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        876 | +
        877 | +        R = []
        878 | +        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        879 | +            for delta_Z_batch, _ in dataloader:
        880 | +                r = self._cell_move(delta_Z_batch)
        881 | +                R.append(tensor_to_numpy(r))
        882 | +                pbar.update(1)
        883 | +
        884 | +        R = np.concatenate(R)
        885 | +        return R
        886 | +
  863   887 |     def preprocess(self, xs, threshold=0):
  864   888 |         if self.loss_func == 'bernoulli':
  865   889 |             ad = sc.AnnData(xs)
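
The new get_expression_response maps such latent-space shifts through the decoder in mini-batches (note that the batched loop calls self._cell_move rather than the _get_expression_response helper defined just above it). A hedged continuation of the previous sketch; chaining the two calls is an assumption about intended use, not something this diff states:

    # Decode latent shifts into expression-level responses, batched with the
    # same DataLoader pattern as above (hypothetical continuation).
    response = model.get_expression_response(delta_zs, batch_size=1024)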
@@ -887,7 +911,7 @@ class SURE(nn.Module):
  887   911 |             threshold: int = 0,
  888   912 |             use_jax: bool = False):
  889   913 |         """
  890       | -        Train the
        914 | +        Train the PerturbFlow model.
  891   915 |
  892   916 |         Parameters
  893   917 |         ----------
@@ -913,7 +937,7 @@ class SURE(nn.Module):
  913   937 |             Parameter for optimization.
  914   938 |         use_jax
  915   939 |             If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
  916       | -            the Python script or Jupyter notebook. It is OK if it is used when runing
        940 | +            the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
  917   941 |         """
  918   942 |         xs = self.preprocess(xs, threshold=threshold)
  919   943 |         xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1031,12 +1055,12 @@ class SURE(nn.Module):
 1031  1055 |
 1032  1056 |
 1033  1057 | EXAMPLE_RUN = (
 1034       | -    "example run:
       1058 | +    "example run: PerturbFlow --help"
 1035  1059 | )
 1036  1060 |
 1037  1061 | def parse_args():
 1038  1062 |     parser = argparse.ArgumentParser(
 1039       | -        description="
       1063 | +        description="PerturbFlow\n{}".format(EXAMPLE_RUN))
 1040  1064 |
 1041  1065 |     parser.add_argument(
 1042  1066 |         "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1254,7 +1278,7 @@ def main():
 1254  1278 |     latent_dist = args.z_dist
 1255  1279 |
 1256  1280 |     ###########################################
 1257       | -
       1281 | +    perturbflow = PerturbFlow(
 1258  1282 |         input_size=input_size,
 1259  1283 |         cell_factor_size=cell_factor_size,
 1260  1284 |         inverse_dispersion=args.inverse_dispersion,
@@ -1277,7 +1301,7 @@ def main():
 1277  1301 |         dtype=dtype,
 1278  1302 |     )
 1279  1303 |
 1280       | -
       1304 | +    perturbflow.fit(xs, us=us,
 1281  1305 |                     num_epochs=args.num_epochs,
 1282  1306 |                     learning_rate=args.learning_rate,
 1283  1307 |                     batch_size=args.batch_size,
@@ -1289,9 +1313,9 @@ def main():
 1289  1313 |
 1290  1314 |     if args.save_model is not None:
 1291  1315 |         if args.save_model.endswith('gz'):
 1292       | -
       1316 | +            PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
 1293  1317 |         else:
 1294       | -
       1318 | +            PerturbFlow.save_model(perturbflow, args.save_model)
 1295  1319 |
 1296  1320 |
 1297  1321 |
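
Pulling together the pieces visible in main(), a condensed end-to-end sketch of the 2.1.0 workflow from Python. Only keywords that appear in the hunks above are used; the toy data and the assumption that the remaining constructor arguments have workable defaults are illustrative:

    import numpy as np
    from SURE import PerturbFlow   # assumed export; see the file list above

    # Toy data: cells x genes counts plus one cell-level factor.
    xs = np.random.poisson(1.0, size=(500, 2000)).astype('float32')
    us = np.random.randint(0, 2, size=(500, 1)).astype('float32')

    model = PerturbFlow(
        input_size=xs.shape[1],
        cell_factor_size=us.shape[1],
    )
    model.fit(xs, us=us, num_epochs=10, learning_rate=1e-3, batch_size=256)

    # main() saves through the class-level helper, with compression=True when
    # the target path ends in 'gz'.
    PerturbFlow.save_model(model, 'perturbflow.pkl.gz', compression=True)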