SURE-tools 2.0.10__tar.gz → 2.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. See the registry's page for this release for more details.

Files changed (29)
  1. {sure_tools-2.0.10 → sure_tools-2.1.0}/PKG-INFO +1 -1
  2. sure_tools-2.0.10/SURE/SURE.py → sure_tools-2.1.0/SURE/PerturbFlow.py +19 -19
  3. sure_tools-2.1.0/SURE/SURE.py +1253 -0
  4. sure_tools-2.1.0/SURE/__init__.py +11 -0
  5. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE_tools.egg-info/PKG-INFO +1 -1
  6. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE_tools.egg-info/SOURCES.txt +1 -0
  7. sure_tools-2.1.0/SURE_tools.egg-info/entry_points.txt +3 -0
  8. {sure_tools-2.0.10 → sure_tools-2.1.0}/setup.py +3 -2
  9. sure_tools-2.0.10/SURE/__init__.py +0 -10
  10. sure_tools-2.0.10/SURE_tools.egg-info/entry_points.txt +0 -2
  11. {sure_tools-2.0.10 → sure_tools-2.1.0}/LICENSE +0 -0
  12. {sure_tools-2.0.10 → sure_tools-2.1.0}/README.md +0 -0
  13. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/assembly/__init__.py +0 -0
  14. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/assembly/assembly.py +0 -0
  15. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/assembly/atlas.py +0 -0
  16. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/atac/__init__.py +0 -0
  17. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/atac/utils.py +0 -0
  18. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/codebook/__init__.py +0 -0
  19. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/codebook/codebook.py +0 -0
  20. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/flow/__init__.py +0 -0
  21. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/flow/plot_quiver.py +0 -0
  22. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/utils/__init__.py +0 -0
  23. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/utils/custom_mlp.py +0 -0
  24. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/utils/queue.py +0 -0
  25. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE/utils/utils.py +0 -0
  26. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE_tools.egg-info/requires.txt +0 -0
  28. {sure_tools-2.0.10 → sure_tools-2.1.0}/SURE_tools.egg-info/top_level.txt +0 -0
  29. {sure_tools-2.0.10 → sure_tools-2.1.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.0.10
3
+ Version: 2.1.0
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -54,7 +54,7 @@ def set_random_seed(seed):
54
54
  # Set seed for Pyro
55
55
  pyro.set_rng_seed(seed)
56
56
 
57
- class SURE(nn.Module):
57
+ class PerturbFlow(nn.Module):
58
58
  """SUccinct REpresentation of single-omics cells
59
59
 
60
60
  Parameters
@@ -330,7 +330,7 @@ class SURE(nn.Module):
330
330
  return xs
331
331
 
332
332
  def model1(self, xs):
333
- pyro.module('sure', self)
333
+ pyro.module('PerturbFlow', self)
334
334
 
335
335
  eps = torch.finfo(xs.dtype).eps
336
336
  batch_size = xs.size(0)
@@ -407,7 +407,7 @@ class SURE(nn.Module):
407
407
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
408
408
 
409
409
  def model2(self, xs, us=None):
410
- pyro.module('sure', self)
410
+ pyro.module('PerturbFlow', self)
411
411
 
412
412
  eps = torch.finfo(xs.dtype).eps
413
413
  batch_size = xs.size(0)
@@ -495,7 +495,7 @@ class SURE(nn.Module):
495
495
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
496
496
 
497
497
  def model3(self, xs, ys, embeds=None):
498
- pyro.module('sure', self)
498
+ pyro.module('PerturbFlow', self)
499
499
 
500
500
  eps = torch.finfo(xs.dtype).eps
501
501
  batch_size = xs.size(0)
@@ -587,7 +587,7 @@ class SURE(nn.Module):
587
587
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
588
588
 
589
589
  def model4(self, xs, us, ys, embeds=None):
590
- pyro.module('sure', self)
590
+ pyro.module('PerturbFlow', self)
591
591
 
592
592
  eps = torch.finfo(xs.dtype).eps
593
593
  batch_size = xs.size(0)
@@ -813,7 +813,7 @@ class SURE(nn.Module):
813
813
  A = np.concatenate(A)
814
814
  return A
815
815
 
816
- def _cell_move(self, xs, factor_idx, perturb):
816
+ def _cell_state_response(self, xs, factor_idx, perturb):
817
817
  zns,_ = self.encoder_zn(xs)
818
818
  if type(factor_idx) == str:
819
819
  factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
@@ -825,7 +825,7 @@ class SURE(nn.Module):
825
825
 
826
826
  return ms
827
827
 
828
- def get_cell_move(self,
828
+ def get_cell_state_response(self,
829
829
  xs,
830
830
  factor_idx,
831
831
  perturb,
@@ -843,14 +843,14 @@ class SURE(nn.Module):
843
843
  Z = []
844
844
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
845
845
  for X_batch, P_batch, _ in dataloader:
846
- zns = self._cell_move(X_batch, factor_idx, P_batch)
846
+ zns = self._cell_state_response(X_batch, factor_idx, P_batch)
847
847
  Z.append(tensor_to_numpy(zns))
848
848
  pbar.update(1)
849
849
 
850
850
  Z = np.concatenate(Z)
851
851
  return Z
852
852
 
853
- def get_metacell_move(self, factor_idx, perturb):
853
+ def get_metacell_response(self, factor_idx, perturb):
854
854
  zs = self._get_codebook()
855
855
  ps = convert_to_tensor(perturb, device=self.get_device())
856
856
 
@@ -860,10 +860,10 @@ class SURE(nn.Module):
860
860
  ms = self.cell_factor_effect[factor_idx]([zs,ps])
861
861
  return tensor_to_numpy(ms)
862
862
 
863
- def _get_expression_responses(self, delta_zs):
863
+ def _get_expression_response(self, delta_zs):
864
864
  return self.decoder_concentrate(delta_zs)
865
865
 
866
- def get_expression_responses(self,
866
+ def get_expression_response(self,
867
867
  delta_zs,
868
868
  batch_size: int = 1024):
869
869
  """
@@ -911,7 +911,7 @@ class SURE(nn.Module):
911
911
  threshold: int = 0,
912
912
  use_jax: bool = False):
913
913
  """
914
- Train the SURE model.
914
+ Train the PerturbFlow model.
915
915
 
916
916
  Parameters
917
917
  ----------
@@ -937,7 +937,7 @@ class SURE(nn.Module):
937
937
  Parameter for optimization.
938
938
  use_jax
939
939
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
940
- the Python script or Jupyter notebook. It is OK if it is used when runing SURE in the shell command.
940
+ the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
941
941
  """
942
942
  xs = self.preprocess(xs, threshold=threshold)
943
943
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1055,12 +1055,12 @@ class SURE(nn.Module):
1055
1055
 
1056
1056
 
1057
1057
  EXAMPLE_RUN = (
1058
- "example run: SURE --help"
1058
+ "example run: PerturbFlow --help"
1059
1059
  )
1060
1060
 
1061
1061
  def parse_args():
1062
1062
  parser = argparse.ArgumentParser(
1063
- description="SURE\n{}".format(EXAMPLE_RUN))
1063
+ description="PerturbFlow\n{}".format(EXAMPLE_RUN))
1064
1064
 
1065
1065
  parser.add_argument(
1066
1066
  "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1278,7 +1278,7 @@ def main():
1278
1278
  latent_dist = args.z_dist
1279
1279
 
1280
1280
  ###########################################
1281
- sure = SURE(
1281
+ perturbflow = PerturbFlow(
1282
1282
  input_size=input_size,
1283
1283
  cell_factor_size=cell_factor_size,
1284
1284
  inverse_dispersion=args.inverse_dispersion,
@@ -1301,7 +1301,7 @@ def main():
1301
1301
  dtype=dtype,
1302
1302
  )
1303
1303
 
1304
- sure.fit(xs, us=us,
1304
+ perturbflow.fit(xs, us=us,
1305
1305
  num_epochs=args.num_epochs,
1306
1306
  learning_rate=args.learning_rate,
1307
1307
  batch_size=args.batch_size,
@@ -1313,9 +1313,9 @@ def main():
1313
1313
 
1314
1314
  if args.save_model is not None:
1315
1315
  if args.save_model.endswith('gz'):
1316
- SURE.save_model(sure, args.save_model, compression=True)
1316
+ PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
1317
1317
  else:
1318
- SURE.save_model(sure, args.save_model)
1318
+ PerturbFlow.save_model(perturbflow, args.save_model)
1319
1319
 
1320
1320
 
1321
1321