SURE-tools 2.1.91.tar.gz → 2.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic.

Files changed (30)
  1. {sure_tools-2.1.91 → sure_tools-2.2.0}/PKG-INFO +1 -1
  2. sure_tools-2.1.91/SURE/PerturbFlow.py → sure_tools-2.2.0/SURE/DensityFlow.py +15 -23
  3. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/__init__.py +3 -3
  4. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/PKG-INFO +1 -1
  5. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/SOURCES.txt +1 -1
  6. {sure_tools-2.1.91 → sure_tools-2.2.0}/setup.py +1 -1
  7. {sure_tools-2.1.91 → sure_tools-2.2.0}/LICENSE +0 -0
  8. {sure_tools-2.1.91 → sure_tools-2.2.0}/README.md +0 -0
  9. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/SURE.py +0 -0
  10. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/assembly/__init__.py +0 -0
  11. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/assembly/assembly.py +0 -0
  12. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/assembly/atlas.py +0 -0
  13. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/atac/__init__.py +0 -0
  14. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/atac/utils.py +0 -0
  15. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/codebook/__init__.py +0 -0
  16. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/codebook/codebook.py +0 -0
  17. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/flow/__init__.py +0 -0
  18. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/flow/flow_stats.py +0 -0
  19. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/flow/plot_quiver.py +0 -0
  20. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/perturb/__init__.py +0 -0
  21. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/perturb/perturb.py +0 -0
  22. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/utils/__init__.py +0 -0
  23. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/utils/custom_mlp.py +0 -0
  24. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/utils/queue.py +0 -0
  25. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE/utils/utils.py +0 -0
  26. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/entry_points.txt +0 -0
  28. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/requires.txt +0 -0
  29. {sure_tools-2.1.91 → sure_tools-2.2.0}/SURE_tools.egg-info/top_level.txt +0 -0
  30. {sure_tools-2.1.91 → sure_tools-2.2.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.91
3
+ Version: 2.2.0
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -54,12 +54,11 @@ def set_random_seed(seed):
54
54
  # Set seed for Pyro
55
55
  pyro.set_rng_seed(seed)
56
56
 
57
- class PerturbFlow(nn.Module):
57
+ class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
60
  codebook_size: int = 200,
61
61
  cell_factor_size: int = 0,
62
- cell_factor_effect_discrete: bool = False,
63
62
  supervised_mode: bool = False,
64
63
  z_dim: int = 10,
65
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
@@ -103,7 +102,6 @@ class PerturbFlow(nn.Module):
103
102
  else:
104
103
  self.use_bias = [not zero_bias] * self.cell_factor_size
105
104
  #self.use_bias = not zero_bias
106
- self.enumrate = cell_factor_effect_discrete
107
105
 
108
106
  self.codebook_weights = None
109
107
 
@@ -310,7 +308,7 @@ class PerturbFlow(nn.Module):
310
308
  return xs
311
309
 
312
310
  def model1(self, xs):
313
- pyro.module('PerturbFlow', self)
311
+ pyro.module('DensityFlow', self)
314
312
 
315
313
  eps = torch.finfo(xs.dtype).eps
316
314
  batch_size = xs.size(0)
@@ -389,7 +387,7 @@ class PerturbFlow(nn.Module):
389
387
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
390
388
 
391
389
  def model2(self, xs, us=None):
392
- pyro.module('PerturbFlow', self)
390
+ pyro.module('DensityFlow', self)
393
391
 
394
392
  eps = torch.finfo(xs.dtype).eps
395
393
  batch_size = xs.size(0)
@@ -431,10 +429,7 @@ class PerturbFlow(nn.Module):
431
429
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
432
430
 
433
431
  if self.cell_factor_size>0:
434
- if self.enumrate:
435
- zus = self._total_effects(zn_loc, us)
436
- else:
437
- zus = self._total_effects(zns, us)
432
+ zus = self._total_effects(zns, us)
438
433
  zs = zns+zus
439
434
  else:
440
435
  zs = zns
@@ -476,7 +471,7 @@ class PerturbFlow(nn.Module):
476
471
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
477
472
 
478
473
  def model3(self, xs, ys, embeds=None):
479
- pyro.module('PerturbFlow', self)
474
+ pyro.module('DensityFlow', self)
480
475
 
481
476
  eps = torch.finfo(xs.dtype).eps
482
477
  batch_size = xs.size(0)
@@ -572,7 +567,7 @@ class PerturbFlow(nn.Module):
572
567
  zns = embeds
573
568
 
574
569
  def model4(self, xs, us, ys, embeds=None):
575
- pyro.module('PerturbFlow', self)
570
+ pyro.module('DensityFlow', self)
576
571
 
577
572
  eps = torch.finfo(xs.dtype).eps
578
573
  batch_size = xs.size(0)
@@ -636,10 +631,7 @@ class PerturbFlow(nn.Module):
636
631
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
637
632
  # else:
638
633
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
639
- if self.enumrate:
640
- zus = self._total_effects(zn_loc, us)
641
- else:
642
- zus = self._total_effects(zns, us)
634
+ zus = self._total_effects(zns, us)
643
635
  zs = zns+zus
644
636
  else:
645
637
  zs = zns
@@ -989,7 +981,7 @@ class PerturbFlow(nn.Module):
989
981
  threshold: int = 0,
990
982
  use_jax: bool = True):
991
983
  """
992
- Train the PerturbFlow model.
984
+ Train the DensityFlow model.
993
985
 
994
986
  Parameters
995
987
  ----------
@@ -1015,7 +1007,7 @@ class PerturbFlow(nn.Module):
1015
1007
  Parameter for optimization.
1016
1008
  use_jax
1017
1009
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
1018
- the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
1010
+ the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
1019
1011
  """
1020
1012
  xs = self.preprocess(xs, threshold=threshold)
1021
1013
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1133,12 +1125,12 @@ class PerturbFlow(nn.Module):
1133
1125
 
1134
1126
 
1135
1127
  EXAMPLE_RUN = (
1136
- "example run: PerturbFlow --help"
1128
+ "example run: DensityFlow --help"
1137
1129
  )
1138
1130
 
1139
1131
  def parse_args():
1140
1132
  parser = argparse.ArgumentParser(
1141
- description="PerturbFlow\n{}".format(EXAMPLE_RUN))
1133
+ description="DensityFlow\n{}".format(EXAMPLE_RUN))
1142
1134
 
1143
1135
  parser.add_argument(
1144
1136
  "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1325,7 +1317,7 @@ def main():
1325
1317
  cell_factor_size = 0 if us is None else us.shape[1]
1326
1318
 
1327
1319
  ###########################################
1328
- perturbflow = PerturbFlow(
1320
+ DensityFlow = DensityFlow(
1329
1321
  input_size=input_size,
1330
1322
  cell_factor_size=cell_factor_size,
1331
1323
  inverse_dispersion=args.inverse_dispersion,
@@ -1344,7 +1336,7 @@ def main():
1344
1336
  dtype=dtype,
1345
1337
  )
1346
1338
 
1347
- perturbflow.fit(xs, us=us,
1339
+ DensityFlow.fit(xs, us=us,
1348
1340
  num_epochs=args.num_epochs,
1349
1341
  learning_rate=args.learning_rate,
1350
1342
  batch_size=args.batch_size,
@@ -1356,9 +1348,9 @@ def main():
1356
1348
 
1357
1349
  if args.save_model is not None:
1358
1350
  if args.save_model.endswith('gz'):
1359
- PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
1351
+ DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
1360
1352
  else:
1361
- PerturbFlow.save_model(perturbflow, args.save_model)
1353
+ DensityFlow.save_model(DensityFlow, args.save_model)
1362
1354
 
1363
1355
 
1364
1356
 
@@ -1,12 +1,12 @@
1
1
  from .SURE import SURE
2
- from .PerturbFlow import PerturbFlow
2
+ from .DensityFlow import DensityFlow
3
3
 
4
4
  from . import utils
5
5
  from . import codebook
6
6
  from . import SURE
7
- from . import PerturbFlow
7
+ from . import DensityFlow
8
8
  from . import atac
9
9
  from . import flow
10
10
  from . import perturb
11
11
 
12
- __all__ = ['SURE', 'PerturbFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
12
+ __all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.91
3
+ Version: 2.2.0
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,7 +1,7 @@
1
1
  LICENSE
2
2
  README.md
3
3
  setup.py
4
- SURE/PerturbFlow.py
4
+ SURE/DensityFlow.py
5
5
  SURE/SURE.py
6
6
  SURE/__init__.py
7
7
  SURE/assembly/__init__.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setup(
7
7
  name='SURE-tools',
8
- version='2.1.91',
8
+ version='2.2.0',
9
9
  description='Succinct Representation of Single Cells',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
File without changes
File without changes
File without changes
File without changes