SURE-tools 2.0.10__tar.gz → 2.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.0.10 → sure_tools-2.1.0}/PKG-INFO +1 -1
- sure_tools-2.0.10/SURE/SURE.py → sure_tools-2.1.0/SURE/PerturbFlow.py +19 -19
- sure_tools-2.1.0/SURE/SURE.py +1253 -0
- sure_tools-2.1.0/SURE/__init__.py +11 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE_tools.egg-info/SOURCES.txt +1 -0
- sure_tools-2.1.0/SURE_tools.egg-info/entry_points.txt +3 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/setup.py +3 -2
- sure_tools-2.0.10/SURE/__init__.py +0 -10
- sure_tools-2.0.10/SURE_tools.egg-info/entry_points.txt +0 -2
- {sure_tools-2.0.10 → sure_tools-2.1.0}/LICENSE +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/README.md +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/atac/utils.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/utils/queue.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/utils/utils.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.0}/setup.cfg +0 -0
|
@@ -54,7 +54,7 @@ def set_random_seed(seed):
|
|
|
54
54
|
# Set seed for Pyro
|
|
55
55
|
pyro.set_rng_seed(seed)
|
|
56
56
|
|
|
57
|
-
class
|
|
57
|
+
class PerturbFlow(nn.Module):
|
|
58
58
|
"""SUccinct REpresentation of single-omics cells
|
|
59
59
|
|
|
60
60
|
Parameters
|
|
@@ -330,7 +330,7 @@ class SURE(nn.Module):
|
|
|
330
330
|
return xs
|
|
331
331
|
|
|
332
332
|
def model1(self, xs):
|
|
333
|
-
pyro.module('
|
|
333
|
+
pyro.module('PerturbFlow', self)
|
|
334
334
|
|
|
335
335
|
eps = torch.finfo(xs.dtype).eps
|
|
336
336
|
batch_size = xs.size(0)
|
|
@@ -407,7 +407,7 @@ class SURE(nn.Module):
|
|
|
407
407
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
408
408
|
|
|
409
409
|
def model2(self, xs, us=None):
|
|
410
|
-
pyro.module('
|
|
410
|
+
pyro.module('PerturbFlow', self)
|
|
411
411
|
|
|
412
412
|
eps = torch.finfo(xs.dtype).eps
|
|
413
413
|
batch_size = xs.size(0)
|
|
@@ -495,7 +495,7 @@ class SURE(nn.Module):
|
|
|
495
495
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
496
496
|
|
|
497
497
|
def model3(self, xs, ys, embeds=None):
|
|
498
|
-
pyro.module('
|
|
498
|
+
pyro.module('PerturbFlow', self)
|
|
499
499
|
|
|
500
500
|
eps = torch.finfo(xs.dtype).eps
|
|
501
501
|
batch_size = xs.size(0)
|
|
@@ -587,7 +587,7 @@ class SURE(nn.Module):
|
|
|
587
587
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
588
588
|
|
|
589
589
|
def model4(self, xs, us, ys, embeds=None):
|
|
590
|
-
pyro.module('
|
|
590
|
+
pyro.module('PerturbFlow', self)
|
|
591
591
|
|
|
592
592
|
eps = torch.finfo(xs.dtype).eps
|
|
593
593
|
batch_size = xs.size(0)
|
|
@@ -813,7 +813,7 @@ class SURE(nn.Module):
|
|
|
813
813
|
A = np.concatenate(A)
|
|
814
814
|
return A
|
|
815
815
|
|
|
816
|
-
def
|
|
816
|
+
def _cell_state_response(self, xs, factor_idx, perturb):
|
|
817
817
|
zns,_ = self.encoder_zn(xs)
|
|
818
818
|
if type(factor_idx) == str:
|
|
819
819
|
factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
|
|
@@ -825,7 +825,7 @@ class SURE(nn.Module):
|
|
|
825
825
|
|
|
826
826
|
return ms
|
|
827
827
|
|
|
828
|
-
def
|
|
828
|
+
def get_cell_state_response(self,
|
|
829
829
|
xs,
|
|
830
830
|
factor_idx,
|
|
831
831
|
perturb,
|
|
@@ -843,14 +843,14 @@ class SURE(nn.Module):
|
|
|
843
843
|
Z = []
|
|
844
844
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
845
845
|
for X_batch, P_batch, _ in dataloader:
|
|
846
|
-
zns = self.
|
|
846
|
+
zns = self._cell_state_response(X_batch, factor_idx, P_batch)
|
|
847
847
|
Z.append(tensor_to_numpy(zns))
|
|
848
848
|
pbar.update(1)
|
|
849
849
|
|
|
850
850
|
Z = np.concatenate(Z)
|
|
851
851
|
return Z
|
|
852
852
|
|
|
853
|
-
def
|
|
853
|
+
def get_metacell_response(self, factor_idx, perturb):
|
|
854
854
|
zs = self._get_codebook()
|
|
855
855
|
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
856
856
|
|
|
@@ -860,10 +860,10 @@ class SURE(nn.Module):
|
|
|
860
860
|
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
861
861
|
return tensor_to_numpy(ms)
|
|
862
862
|
|
|
863
|
-
def
|
|
863
|
+
def _get_expression_response(self, delta_zs):
|
|
864
864
|
return self.decoder_concentrate(delta_zs)
|
|
865
865
|
|
|
866
|
-
def
|
|
866
|
+
def get_expression_response(self,
|
|
867
867
|
delta_zs,
|
|
868
868
|
batch_size: int = 1024):
|
|
869
869
|
"""
|
|
@@ -911,7 +911,7 @@ class SURE(nn.Module):
|
|
|
911
911
|
threshold: int = 0,
|
|
912
912
|
use_jax: bool = False):
|
|
913
913
|
"""
|
|
914
|
-
Train the
|
|
914
|
+
Train the PerturbFlow model.
|
|
915
915
|
|
|
916
916
|
Parameters
|
|
917
917
|
----------
|
|
@@ -937,7 +937,7 @@ class SURE(nn.Module):
|
|
|
937
937
|
Parameter for optimization.
|
|
938
938
|
use_jax
|
|
939
939
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
940
|
-
the Python script or Jupyter notebook. It is OK if it is used when runing
|
|
940
|
+
the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
|
|
941
941
|
"""
|
|
942
942
|
xs = self.preprocess(xs, threshold=threshold)
|
|
943
943
|
xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
|
|
@@ -1055,12 +1055,12 @@ class SURE(nn.Module):
|
|
|
1055
1055
|
|
|
1056
1056
|
|
|
1057
1057
|
EXAMPLE_RUN = (
|
|
1058
|
-
"example run:
|
|
1058
|
+
"example run: PerturbFlow --help"
|
|
1059
1059
|
)
|
|
1060
1060
|
|
|
1061
1061
|
def parse_args():
|
|
1062
1062
|
parser = argparse.ArgumentParser(
|
|
1063
|
-
description="
|
|
1063
|
+
description="PerturbFlow\n{}".format(EXAMPLE_RUN))
|
|
1064
1064
|
|
|
1065
1065
|
parser.add_argument(
|
|
1066
1066
|
"--cuda", action="store_true", help="use GPU(s) to speed up training"
|
|
@@ -1278,7 +1278,7 @@ def main():
|
|
|
1278
1278
|
latent_dist = args.z_dist
|
|
1279
1279
|
|
|
1280
1280
|
###########################################
|
|
1281
|
-
|
|
1281
|
+
perturbflow = PerturbFlow(
|
|
1282
1282
|
input_size=input_size,
|
|
1283
1283
|
cell_factor_size=cell_factor_size,
|
|
1284
1284
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1301,7 +1301,7 @@ def main():
|
|
|
1301
1301
|
dtype=dtype,
|
|
1302
1302
|
)
|
|
1303
1303
|
|
|
1304
|
-
|
|
1304
|
+
perturbflow.fit(xs, us=us,
|
|
1305
1305
|
num_epochs=args.num_epochs,
|
|
1306
1306
|
learning_rate=args.learning_rate,
|
|
1307
1307
|
batch_size=args.batch_size,
|
|
@@ -1313,9 +1313,9 @@ def main():
|
|
|
1313
1313
|
|
|
1314
1314
|
if args.save_model is not None:
|
|
1315
1315
|
if args.save_model.endswith('gz'):
|
|
1316
|
-
|
|
1316
|
+
PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
|
|
1317
1317
|
else:
|
|
1318
|
-
|
|
1318
|
+
PerturbFlow.save_model(perturbflow, args.save_model)
|
|
1319
1319
|
|
|
1320
1320
|
|
|
1321
1321
|
|