SURE-tools 2.1.92__tar.gz → 2.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.1.92 → sure_tools-2.2.0}/PKG-INFO +1 -1
- sure_tools-2.1.92/SURE/PerturbFlow.py → sure_tools-2.2.0/SURE/DensityFlow.py +13 -13
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/__init__.py +3 -3
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE_tools.egg-info/SOURCES.txt +1 -1
- {sure_tools-2.1.92 → sure_tools-2.2.0}/setup.py +1 -1
- {sure_tools-2.1.92 → sure_tools-2.2.0}/LICENSE +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/README.md +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/SURE.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.92 → sure_tools-2.2.0}/setup.cfg +0 -0
|
@@ -54,7 +54,7 @@ def set_random_seed(seed):
|
|
|
54
54
|
# Set seed for Pyro
|
|
55
55
|
pyro.set_rng_seed(seed)
|
|
56
56
|
|
|
57
|
-
class
|
|
57
|
+
class DensityFlow(nn.Module):
|
|
58
58
|
def __init__(self,
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
@@ -308,7 +308,7 @@ class PerturbFlow(nn.Module):
|
|
|
308
308
|
return xs
|
|
309
309
|
|
|
310
310
|
def model1(self, xs):
|
|
311
|
-
pyro.module('
|
|
311
|
+
pyro.module('DensityFlow', self)
|
|
312
312
|
|
|
313
313
|
eps = torch.finfo(xs.dtype).eps
|
|
314
314
|
batch_size = xs.size(0)
|
|
@@ -387,7 +387,7 @@ class PerturbFlow(nn.Module):
|
|
|
387
387
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
388
388
|
|
|
389
389
|
def model2(self, xs, us=None):
|
|
390
|
-
pyro.module('
|
|
390
|
+
pyro.module('DensityFlow', self)
|
|
391
391
|
|
|
392
392
|
eps = torch.finfo(xs.dtype).eps
|
|
393
393
|
batch_size = xs.size(0)
|
|
@@ -471,7 +471,7 @@ class PerturbFlow(nn.Module):
|
|
|
471
471
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
472
472
|
|
|
473
473
|
def model3(self, xs, ys, embeds=None):
|
|
474
|
-
pyro.module('
|
|
474
|
+
pyro.module('DensityFlow', self)
|
|
475
475
|
|
|
476
476
|
eps = torch.finfo(xs.dtype).eps
|
|
477
477
|
batch_size = xs.size(0)
|
|
@@ -567,7 +567,7 @@ class PerturbFlow(nn.Module):
|
|
|
567
567
|
zns = embeds
|
|
568
568
|
|
|
569
569
|
def model4(self, xs, us, ys, embeds=None):
|
|
570
|
-
pyro.module('
|
|
570
|
+
pyro.module('DensityFlow', self)
|
|
571
571
|
|
|
572
572
|
eps = torch.finfo(xs.dtype).eps
|
|
573
573
|
batch_size = xs.size(0)
|
|
@@ -981,7 +981,7 @@ class PerturbFlow(nn.Module):
|
|
|
981
981
|
threshold: int = 0,
|
|
982
982
|
use_jax: bool = True):
|
|
983
983
|
"""
|
|
984
|
-
Train the
|
|
984
|
+
Train the DensityFlow model.
|
|
985
985
|
|
|
986
986
|
Parameters
|
|
987
987
|
----------
|
|
@@ -1007,7 +1007,7 @@ class PerturbFlow(nn.Module):
|
|
|
1007
1007
|
Parameter for optimization.
|
|
1008
1008
|
use_jax
|
|
1009
1009
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
1010
|
-
the Python script or Jupyter notebook. It is OK if it is used when runing
|
|
1010
|
+
the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
|
|
1011
1011
|
"""
|
|
1012
1012
|
xs = self.preprocess(xs, threshold=threshold)
|
|
1013
1013
|
xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
|
|
@@ -1125,12 +1125,12 @@ class PerturbFlow(nn.Module):
|
|
|
1125
1125
|
|
|
1126
1126
|
|
|
1127
1127
|
EXAMPLE_RUN = (
|
|
1128
|
-
"example run:
|
|
1128
|
+
"example run: DensityFlow --help"
|
|
1129
1129
|
)
|
|
1130
1130
|
|
|
1131
1131
|
def parse_args():
|
|
1132
1132
|
parser = argparse.ArgumentParser(
|
|
1133
|
-
description="
|
|
1133
|
+
description="DensityFlow\n{}".format(EXAMPLE_RUN))
|
|
1134
1134
|
|
|
1135
1135
|
parser.add_argument(
|
|
1136
1136
|
"--cuda", action="store_true", help="use GPU(s) to speed up training"
|
|
@@ -1317,7 +1317,7 @@ def main():
|
|
|
1317
1317
|
cell_factor_size = 0 if us is None else us.shape[1]
|
|
1318
1318
|
|
|
1319
1319
|
###########################################
|
|
1320
|
-
|
|
1320
|
+
DensityFlow = DensityFlow(
|
|
1321
1321
|
input_size=input_size,
|
|
1322
1322
|
cell_factor_size=cell_factor_size,
|
|
1323
1323
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1336,7 +1336,7 @@ def main():
|
|
|
1336
1336
|
dtype=dtype,
|
|
1337
1337
|
)
|
|
1338
1338
|
|
|
1339
|
-
|
|
1339
|
+
DensityFlow.fit(xs, us=us,
|
|
1340
1340
|
num_epochs=args.num_epochs,
|
|
1341
1341
|
learning_rate=args.learning_rate,
|
|
1342
1342
|
batch_size=args.batch_size,
|
|
@@ -1348,9 +1348,9 @@ def main():
|
|
|
1348
1348
|
|
|
1349
1349
|
if args.save_model is not None:
|
|
1350
1350
|
if args.save_model.endswith('gz'):
|
|
1351
|
-
|
|
1351
|
+
DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
|
|
1352
1352
|
else:
|
|
1353
|
-
|
|
1353
|
+
DensityFlow.save_model(DensityFlow, args.save_model)
|
|
1354
1354
|
|
|
1355
1355
|
|
|
1356
1356
|
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
from .SURE import SURE
|
|
2
|
-
from .
|
|
2
|
+
from .DensityFlow import DensityFlow
|
|
3
3
|
|
|
4
4
|
from . import utils
|
|
5
5
|
from . import codebook
|
|
6
6
|
from . import SURE
|
|
7
|
-
from . import
|
|
7
|
+
from . import DensityFlow
|
|
8
8
|
from . import atac
|
|
9
9
|
from . import flow
|
|
10
10
|
from . import perturb
|
|
11
11
|
|
|
12
|
-
__all__ = ['SURE', '
|
|
12
|
+
__all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|