SURE-tools 2.1.91__tar.gz → 2.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.1.91 → sure_tools-2.2.0}/PKG-INFO +1 -1
- sure_tools-2.1.91/SURE/PerturbFlow.py → sure_tools-2.2.0/SURE/DensityFlow.py +15 -23
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/__init__.py +3 -3
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/SOURCES.txt +1 -1
- {sure_tools-2.1.91 → sure_tools-2.2.0}/setup.py +1 -1
- {sure_tools-2.1.91 → sure_tools-2.2.0}/LICENSE +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/README.md +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/SURE.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.0}/setup.cfg +0 -0
|
@@ -54,12 +54,11 @@ def set_random_seed(seed):
|
|
|
54
54
|
# Set seed for Pyro
|
|
55
55
|
pyro.set_rng_seed(seed)
|
|
56
56
|
|
|
57
|
-
class
|
|
57
|
+
class DensityFlow(nn.Module):
|
|
58
58
|
def __init__(self,
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
|
-
cell_factor_effect_discrete: bool = False,
|
|
63
62
|
supervised_mode: bool = False,
|
|
64
63
|
z_dim: int = 10,
|
|
65
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
@@ -103,7 +102,6 @@ class PerturbFlow(nn.Module):
|
|
|
103
102
|
else:
|
|
104
103
|
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
105
104
|
#self.use_bias = not zero_bias
|
|
106
|
-
self.enumrate = cell_factor_effect_discrete
|
|
107
105
|
|
|
108
106
|
self.codebook_weights = None
|
|
109
107
|
|
|
@@ -310,7 +308,7 @@ class PerturbFlow(nn.Module):
|
|
|
310
308
|
return xs
|
|
311
309
|
|
|
312
310
|
def model1(self, xs):
|
|
313
|
-
pyro.module('
|
|
311
|
+
pyro.module('DensityFlow', self)
|
|
314
312
|
|
|
315
313
|
eps = torch.finfo(xs.dtype).eps
|
|
316
314
|
batch_size = xs.size(0)
|
|
@@ -389,7 +387,7 @@ class PerturbFlow(nn.Module):
|
|
|
389
387
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
390
388
|
|
|
391
389
|
def model2(self, xs, us=None):
|
|
392
|
-
pyro.module('
|
|
390
|
+
pyro.module('DensityFlow', self)
|
|
393
391
|
|
|
394
392
|
eps = torch.finfo(xs.dtype).eps
|
|
395
393
|
batch_size = xs.size(0)
|
|
@@ -431,10 +429,7 @@ class PerturbFlow(nn.Module):
|
|
|
431
429
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
432
430
|
|
|
433
431
|
if self.cell_factor_size>0:
|
|
434
|
-
|
|
435
|
-
zus = self._total_effects(zn_loc, us)
|
|
436
|
-
else:
|
|
437
|
-
zus = self._total_effects(zns, us)
|
|
432
|
+
zus = self._total_effects(zns, us)
|
|
438
433
|
zs = zns+zus
|
|
439
434
|
else:
|
|
440
435
|
zs = zns
|
|
@@ -476,7 +471,7 @@ class PerturbFlow(nn.Module):
|
|
|
476
471
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
477
472
|
|
|
478
473
|
def model3(self, xs, ys, embeds=None):
|
|
479
|
-
pyro.module('
|
|
474
|
+
pyro.module('DensityFlow', self)
|
|
480
475
|
|
|
481
476
|
eps = torch.finfo(xs.dtype).eps
|
|
482
477
|
batch_size = xs.size(0)
|
|
@@ -572,7 +567,7 @@ class PerturbFlow(nn.Module):
|
|
|
572
567
|
zns = embeds
|
|
573
568
|
|
|
574
569
|
def model4(self, xs, us, ys, embeds=None):
|
|
575
|
-
pyro.module('
|
|
570
|
+
pyro.module('DensityFlow', self)
|
|
576
571
|
|
|
577
572
|
eps = torch.finfo(xs.dtype).eps
|
|
578
573
|
batch_size = xs.size(0)
|
|
@@ -636,10 +631,7 @@ class PerturbFlow(nn.Module):
|
|
|
636
631
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
637
632
|
# else:
|
|
638
633
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
639
|
-
|
|
640
|
-
zus = self._total_effects(zn_loc, us)
|
|
641
|
-
else:
|
|
642
|
-
zus = self._total_effects(zns, us)
|
|
634
|
+
zus = self._total_effects(zns, us)
|
|
643
635
|
zs = zns+zus
|
|
644
636
|
else:
|
|
645
637
|
zs = zns
|
|
@@ -989,7 +981,7 @@ class PerturbFlow(nn.Module):
|
|
|
989
981
|
threshold: int = 0,
|
|
990
982
|
use_jax: bool = True):
|
|
991
983
|
"""
|
|
992
|
-
Train the
|
|
984
|
+
Train the DensityFlow model.
|
|
993
985
|
|
|
994
986
|
Parameters
|
|
995
987
|
----------
|
|
@@ -1015,7 +1007,7 @@ class PerturbFlow(nn.Module):
|
|
|
1015
1007
|
Parameter for optimization.
|
|
1016
1008
|
use_jax
|
|
1017
1009
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
1018
|
-
the Python script or Jupyter notebook. It is OK if it is used when runing
|
|
1010
|
+
the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
|
|
1019
1011
|
"""
|
|
1020
1012
|
xs = self.preprocess(xs, threshold=threshold)
|
|
1021
1013
|
xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
|
|
@@ -1133,12 +1125,12 @@ class PerturbFlow(nn.Module):
|
|
|
1133
1125
|
|
|
1134
1126
|
|
|
1135
1127
|
EXAMPLE_RUN = (
|
|
1136
|
-
"example run:
|
|
1128
|
+
"example run: DensityFlow --help"
|
|
1137
1129
|
)
|
|
1138
1130
|
|
|
1139
1131
|
def parse_args():
|
|
1140
1132
|
parser = argparse.ArgumentParser(
|
|
1141
|
-
description="
|
|
1133
|
+
description="DensityFlow\n{}".format(EXAMPLE_RUN))
|
|
1142
1134
|
|
|
1143
1135
|
parser.add_argument(
|
|
1144
1136
|
"--cuda", action="store_true", help="use GPU(s) to speed up training"
|
|
@@ -1325,7 +1317,7 @@ def main():
|
|
|
1325
1317
|
cell_factor_size = 0 if us is None else us.shape[1]
|
|
1326
1318
|
|
|
1327
1319
|
###########################################
|
|
1328
|
-
|
|
1320
|
+
DensityFlow = DensityFlow(
|
|
1329
1321
|
input_size=input_size,
|
|
1330
1322
|
cell_factor_size=cell_factor_size,
|
|
1331
1323
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1344,7 +1336,7 @@ def main():
|
|
|
1344
1336
|
dtype=dtype,
|
|
1345
1337
|
)
|
|
1346
1338
|
|
|
1347
|
-
|
|
1339
|
+
DensityFlow.fit(xs, us=us,
|
|
1348
1340
|
num_epochs=args.num_epochs,
|
|
1349
1341
|
learning_rate=args.learning_rate,
|
|
1350
1342
|
batch_size=args.batch_size,
|
|
@@ -1356,9 +1348,9 @@ def main():
|
|
|
1356
1348
|
|
|
1357
1349
|
if args.save_model is not None:
|
|
1358
1350
|
if args.save_model.endswith('gz'):
|
|
1359
|
-
|
|
1351
|
+
DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
|
|
1360
1352
|
else:
|
|
1361
|
-
|
|
1353
|
+
DensityFlow.save_model(DensityFlow, args.save_model)
|
|
1362
1354
|
|
|
1363
1355
|
|
|
1364
1356
|
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
from .SURE import SURE
|
|
2
|
-
from .
|
|
2
|
+
from .DensityFlow import DensityFlow
|
|
3
3
|
|
|
4
4
|
from . import utils
|
|
5
5
|
from . import codebook
|
|
6
6
|
from . import SURE
|
|
7
|
-
from . import
|
|
7
|
+
from . import DensityFlow
|
|
8
8
|
from . import atac
|
|
9
9
|
from . import flow
|
|
10
10
|
from . import perturb
|
|
11
11
|
|
|
12
|
-
__all__ = ['SURE', '
|
|
12
|
+
__all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|