SURE-tools 2.0.10__tar.gz → 2.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-2.0.10 → sure_tools-2.1.1}/PKG-INFO +1 -1
- sure_tools-2.0.10/SURE/SURE.py → sure_tools-2.1.1/SURE/PerturbFlow.py +21 -30
- sure_tools-2.1.1/SURE/SURE.py +1236 -0
- sure_tools-2.1.1/SURE/__init__.py +11 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE_tools.egg-info/SOURCES.txt +1 -0
- sure_tools-2.1.1/SURE_tools.egg-info/entry_points.txt +3 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/setup.py +3 -2
- sure_tools-2.0.10/SURE/__init__.py +0 -10
- sure_tools-2.0.10/SURE_tools.egg-info/entry_points.txt +0 -2
- {sure_tools-2.0.10 → sure_tools-2.1.1}/LICENSE +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/README.md +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/atac/utils.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/utils/queue.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/utils/utils.py +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.0.10 → sure_tools-2.1.1}/setup.cfg +0 -0
|
@@ -54,7 +54,7 @@ def set_random_seed(seed):
|
|
|
54
54
|
# Set seed for Pyro
|
|
55
55
|
pyro.set_rng_seed(seed)
|
|
56
56
|
|
|
57
|
-
class
|
|
57
|
+
class PerturbFlow(nn.Module):
|
|
58
58
|
"""SUccinct REpresentation of single-omics cells
|
|
59
59
|
|
|
60
60
|
Parameters
|
|
@@ -97,7 +97,6 @@ class SURE(nn.Module):
|
|
|
97
97
|
input_size: int,
|
|
98
98
|
codebook_size: int = 200,
|
|
99
99
|
cell_factor_size: int = 0,
|
|
100
|
-
cell_factor_names: list = None,
|
|
101
100
|
supervised_mode: bool = False,
|
|
102
101
|
z_dim: int = 10,
|
|
103
102
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
@@ -135,7 +134,6 @@ class SURE(nn.Module):
|
|
|
135
134
|
self.post_layer_fct = post_layer_fct
|
|
136
135
|
self.post_act_fct = post_act_fct
|
|
137
136
|
self.hidden_layer_activation = hidden_layer_activation
|
|
138
|
-
self.cell_factor_names = cell_factor_names
|
|
139
137
|
|
|
140
138
|
self.codebook_weights = None
|
|
141
139
|
|
|
@@ -330,7 +328,7 @@ class SURE(nn.Module):
|
|
|
330
328
|
return xs
|
|
331
329
|
|
|
332
330
|
def model1(self, xs):
|
|
333
|
-
pyro.module('
|
|
331
|
+
pyro.module('PerturbFlow', self)
|
|
334
332
|
|
|
335
333
|
eps = torch.finfo(xs.dtype).eps
|
|
336
334
|
batch_size = xs.size(0)
|
|
@@ -407,7 +405,7 @@ class SURE(nn.Module):
|
|
|
407
405
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
408
406
|
|
|
409
407
|
def model2(self, xs, us=None):
|
|
410
|
-
pyro.module('
|
|
408
|
+
pyro.module('PerturbFlow', self)
|
|
411
409
|
|
|
412
410
|
eps = torch.finfo(xs.dtype).eps
|
|
413
411
|
batch_size = xs.size(0)
|
|
@@ -495,7 +493,7 @@ class SURE(nn.Module):
|
|
|
495
493
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
496
494
|
|
|
497
495
|
def model3(self, xs, ys, embeds=None):
|
|
498
|
-
pyro.module('
|
|
496
|
+
pyro.module('PerturbFlow', self)
|
|
499
497
|
|
|
500
498
|
eps = torch.finfo(xs.dtype).eps
|
|
501
499
|
batch_size = xs.size(0)
|
|
@@ -587,7 +585,7 @@ class SURE(nn.Module):
|
|
|
587
585
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
588
586
|
|
|
589
587
|
def model4(self, xs, us, ys, embeds=None):
|
|
590
|
-
pyro.module('
|
|
588
|
+
pyro.module('PerturbFlow', self)
|
|
591
589
|
|
|
592
590
|
eps = torch.finfo(xs.dtype).eps
|
|
593
591
|
batch_size = xs.size(0)
|
|
@@ -813,11 +811,8 @@ class SURE(nn.Module):
|
|
|
813
811
|
A = np.concatenate(A)
|
|
814
812
|
return A
|
|
815
813
|
|
|
816
|
-
def
|
|
817
|
-
zns,_ = self.encoder_zn(xs)
|
|
818
|
-
if type(factor_idx) == str:
|
|
819
|
-
factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
|
|
820
|
-
|
|
814
|
+
def _cell_response(self, xs, factor_idx, perturb):
|
|
815
|
+
zns,_ = self.encoder_zn(xs)
|
|
821
816
|
if perturb.ndim==2:
|
|
822
817
|
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
823
818
|
else:
|
|
@@ -825,7 +820,7 @@ class SURE(nn.Module):
|
|
|
825
820
|
|
|
826
821
|
return ms
|
|
827
822
|
|
|
828
|
-
def
|
|
823
|
+
def get_cell_response(self,
|
|
829
824
|
xs,
|
|
830
825
|
factor_idx,
|
|
831
826
|
perturb,
|
|
@@ -843,27 +838,23 @@ class SURE(nn.Module):
|
|
|
843
838
|
Z = []
|
|
844
839
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
845
840
|
for X_batch, P_batch, _ in dataloader:
|
|
846
|
-
zns = self.
|
|
841
|
+
zns = self._cell_response(X_batch, factor_idx, P_batch)
|
|
847
842
|
Z.append(tensor_to_numpy(zns))
|
|
848
843
|
pbar.update(1)
|
|
849
844
|
|
|
850
845
|
Z = np.concatenate(Z)
|
|
851
846
|
return Z
|
|
852
847
|
|
|
853
|
-
def
|
|
848
|
+
def get_metacell_response(self, factor_idx, perturb):
|
|
854
849
|
zs = self._get_codebook()
|
|
855
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
856
|
-
|
|
857
|
-
if type(factor_idx) == str:
|
|
858
|
-
factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
|
|
859
|
-
|
|
850
|
+
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
860
851
|
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
861
852
|
return tensor_to_numpy(ms)
|
|
862
853
|
|
|
863
|
-
def
|
|
854
|
+
def _get_expression_response(self, delta_zs):
|
|
864
855
|
return self.decoder_concentrate(delta_zs)
|
|
865
856
|
|
|
866
|
-
def
|
|
857
|
+
def get_expression_response(self,
|
|
867
858
|
delta_zs,
|
|
868
859
|
batch_size: int = 1024):
|
|
869
860
|
"""
|
|
@@ -911,7 +902,7 @@ class SURE(nn.Module):
|
|
|
911
902
|
threshold: int = 0,
|
|
912
903
|
use_jax: bool = False):
|
|
913
904
|
"""
|
|
914
|
-
Train the
|
|
905
|
+
Train the PerturbFlow model.
|
|
915
906
|
|
|
916
907
|
Parameters
|
|
917
908
|
----------
|
|
@@ -937,7 +928,7 @@ class SURE(nn.Module):
|
|
|
937
928
|
Parameter for optimization.
|
|
938
929
|
use_jax
|
|
939
930
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
940
|
-
the Python script or Jupyter notebook. It is OK if it is used when runing
|
|
931
|
+
the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
|
|
941
932
|
"""
|
|
942
933
|
xs = self.preprocess(xs, threshold=threshold)
|
|
943
934
|
xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
|
|
@@ -1055,12 +1046,12 @@ class SURE(nn.Module):
|
|
|
1055
1046
|
|
|
1056
1047
|
|
|
1057
1048
|
EXAMPLE_RUN = (
|
|
1058
|
-
"example run:
|
|
1049
|
+
"example run: PerturbFlow --help"
|
|
1059
1050
|
)
|
|
1060
1051
|
|
|
1061
1052
|
def parse_args():
|
|
1062
1053
|
parser = argparse.ArgumentParser(
|
|
1063
|
-
description="
|
|
1054
|
+
description="PerturbFlow\n{}".format(EXAMPLE_RUN))
|
|
1064
1055
|
|
|
1065
1056
|
parser.add_argument(
|
|
1066
1057
|
"--cuda", action="store_true", help="use GPU(s) to speed up training"
|
|
@@ -1278,7 +1269,7 @@ def main():
|
|
|
1278
1269
|
latent_dist = args.z_dist
|
|
1279
1270
|
|
|
1280
1271
|
###########################################
|
|
1281
|
-
|
|
1272
|
+
perturbflow = PerturbFlow(
|
|
1282
1273
|
input_size=input_size,
|
|
1283
1274
|
cell_factor_size=cell_factor_size,
|
|
1284
1275
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1301,7 +1292,7 @@ def main():
|
|
|
1301
1292
|
dtype=dtype,
|
|
1302
1293
|
)
|
|
1303
1294
|
|
|
1304
|
-
|
|
1295
|
+
perturbflow.fit(xs, us=us,
|
|
1305
1296
|
num_epochs=args.num_epochs,
|
|
1306
1297
|
learning_rate=args.learning_rate,
|
|
1307
1298
|
batch_size=args.batch_size,
|
|
@@ -1313,9 +1304,9 @@ def main():
|
|
|
1313
1304
|
|
|
1314
1305
|
if args.save_model is not None:
|
|
1315
1306
|
if args.save_model.endswith('gz'):
|
|
1316
|
-
|
|
1307
|
+
PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
|
|
1317
1308
|
else:
|
|
1318
|
-
|
|
1309
|
+
PerturbFlow.save_model(perturbflow, args.save_model)
|
|
1319
1310
|
|
|
1320
1311
|
|
|
1321
1312
|
|