SURE-tools 2.0.9__tar.gz → 2.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

Files changed (29) hide show
  1. {sure_tools-2.0.9 → sure_tools-2.1.0}/PKG-INFO +1 -1
  2. sure_tools-2.0.9/SURE/SURE.py → sure_tools-2.1.0/SURE/PerturbFlow.py +41 -17
  3. sure_tools-2.1.0/SURE/SURE.py +1253 -0
  4. sure_tools-2.1.0/SURE/__init__.py +11 -0
  5. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE_tools.egg-info/PKG-INFO +1 -1
  6. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE_tools.egg-info/SOURCES.txt +1 -0
  7. sure_tools-2.1.0/SURE_tools.egg-info/entry_points.txt +3 -0
  8. {sure_tools-2.0.9 → sure_tools-2.1.0}/setup.py +3 -2
  9. sure_tools-2.0.9/SURE/__init__.py +0 -10
  10. sure_tools-2.0.9/SURE_tools.egg-info/entry_points.txt +0 -2
  11. {sure_tools-2.0.9 → sure_tools-2.1.0}/LICENSE +0 -0
  12. {sure_tools-2.0.9 → sure_tools-2.1.0}/README.md +0 -0
  13. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/assembly/__init__.py +0 -0
  14. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/assembly/assembly.py +0 -0
  15. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/assembly/atlas.py +0 -0
  16. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/atac/__init__.py +0 -0
  17. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/atac/utils.py +0 -0
  18. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/codebook/__init__.py +0 -0
  19. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/codebook/codebook.py +0 -0
  20. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/flow/__init__.py +0 -0
  21. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/flow/plot_quiver.py +0 -0
  22. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/utils/__init__.py +0 -0
  23. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/utils/custom_mlp.py +0 -0
  24. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/utils/queue.py +0 -0
  25. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE/utils/utils.py +0 -0
  26. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE_tools.egg-info/requires.txt +0 -0
  28. {sure_tools-2.0.9 → sure_tools-2.1.0}/SURE_tools.egg-info/top_level.txt +0 -0
  29. {sure_tools-2.0.9 → sure_tools-2.1.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.0.9
3
+ Version: 2.1.0
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -54,7 +54,7 @@ def set_random_seed(seed):
54
54
  # Set seed for Pyro
55
55
  pyro.set_rng_seed(seed)
56
56
 
57
- class SURE(nn.Module):
57
+ class PerturbFlow(nn.Module):
58
58
  """SUccinct REpresentation of single-omics cells
59
59
 
60
60
  Parameters
@@ -330,7 +330,7 @@ class SURE(nn.Module):
330
330
  return xs
331
331
 
332
332
  def model1(self, xs):
333
- pyro.module('sure', self)
333
+ pyro.module('PerturbFlow', self)
334
334
 
335
335
  eps = torch.finfo(xs.dtype).eps
336
336
  batch_size = xs.size(0)
@@ -407,7 +407,7 @@ class SURE(nn.Module):
407
407
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
408
408
 
409
409
  def model2(self, xs, us=None):
410
- pyro.module('sure', self)
410
+ pyro.module('PerturbFlow', self)
411
411
 
412
412
  eps = torch.finfo(xs.dtype).eps
413
413
  batch_size = xs.size(0)
@@ -495,7 +495,7 @@ class SURE(nn.Module):
495
495
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
496
496
 
497
497
  def model3(self, xs, ys, embeds=None):
498
- pyro.module('sure', self)
498
+ pyro.module('PerturbFlow', self)
499
499
 
500
500
  eps = torch.finfo(xs.dtype).eps
501
501
  batch_size = xs.size(0)
@@ -587,7 +587,7 @@ class SURE(nn.Module):
587
587
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
588
588
 
589
589
  def model4(self, xs, us, ys, embeds=None):
590
- pyro.module('sure', self)
590
+ pyro.module('PerturbFlow', self)
591
591
 
592
592
  eps = torch.finfo(xs.dtype).eps
593
593
  batch_size = xs.size(0)
@@ -813,7 +813,7 @@ class SURE(nn.Module):
813
813
  A = np.concatenate(A)
814
814
  return A
815
815
 
816
- def _cell_move(self, xs, factor_idx, perturb):
816
+ def _cell_state_response(self, xs, factor_idx, perturb):
817
817
  zns,_ = self.encoder_zn(xs)
818
818
  if type(factor_idx) == str:
819
819
  factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
@@ -825,7 +825,7 @@ class SURE(nn.Module):
825
825
 
826
826
  return ms
827
827
 
828
- def get_cell_move(self,
828
+ def get_cell_state_response(self,
829
829
  xs,
830
830
  factor_idx,
831
831
  perturb,
@@ -843,14 +843,14 @@ class SURE(nn.Module):
843
843
  Z = []
844
844
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
845
845
  for X_batch, P_batch, _ in dataloader:
846
- zns = self._cell_move(X_batch, factor_idx, P_batch)
846
+ zns = self._cell_state_response(X_batch, factor_idx, P_batch)
847
847
  Z.append(tensor_to_numpy(zns))
848
848
  pbar.update(1)
849
849
 
850
850
  Z = np.concatenate(Z)
851
851
  return Z
852
852
 
853
- def get_metacell_move(self, factor_idx, perturb):
853
+ def get_metacell_response(self, factor_idx, perturb):
854
854
  zs = self._get_codebook()
855
855
  ps = convert_to_tensor(perturb, device=self.get_device())
856
856
 
@@ -860,6 +860,30 @@ class SURE(nn.Module):
860
860
  ms = self.cell_factor_effect[factor_idx]([zs,ps])
861
861
  return tensor_to_numpy(ms)
862
862
 
863
+ def _get_expression_response(self, delta_zs):
864
+ return self.decoder_concentrate(delta_zs)
865
+
866
+ def get_expression_response(self,
867
+ delta_zs,
868
+ batch_size: int = 1024):
869
+ """
870
+ Pass latent-space changes (delta_zs) through decoder_concentrate in batches of batch_size.
871
+ Return the per-batch decoder outputs concatenated into a single NumPy array.
872
+ """
873
+ delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
874
+ dataset = CustomDataset(delta_zs)
875
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
876
+
877
+ R = []
878
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
879
+ for delta_Z_batch, _ in dataloader:
880
+ r = self._get_expression_response(delta_Z_batch)
881
+ R.append(tensor_to_numpy(r))
882
+ pbar.update(1)
883
+
884
+ R = np.concatenate(R)
885
+ return R
886
+
863
887
  def preprocess(self, xs, threshold=0):
864
888
  if self.loss_func == 'bernoulli':
865
889
  ad = sc.AnnData(xs)
@@ -887,7 +911,7 @@ class SURE(nn.Module):
887
911
  threshold: int = 0,
888
912
  use_jax: bool = False):
889
913
  """
890
- Train the SURE model.
914
+ Train the PerturbFlow model.
891
915
 
892
916
  Parameters
893
917
  ----------
@@ -913,7 +937,7 @@ class SURE(nn.Module):
913
937
  Parameter for optimization.
914
938
  use_jax
915
939
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
916
- the Python script or Jupyter notebook. It is OK if it is used when runing SURE in the shell command.
940
+ the Python script or Jupyter notebook. It is OK if it is used when running PerturbFlow in the shell command.
917
941
  """
918
942
  xs = self.preprocess(xs, threshold=threshold)
919
943
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1031,12 +1055,12 @@ class SURE(nn.Module):
1031
1055
 
1032
1056
 
1033
1057
  EXAMPLE_RUN = (
1034
- "example run: SURE --help"
1058
+ "example run: PerturbFlow --help"
1035
1059
  )
1036
1060
 
1037
1061
  def parse_args():
1038
1062
  parser = argparse.ArgumentParser(
1039
- description="SURE\n{}".format(EXAMPLE_RUN))
1063
+ description="PerturbFlow\n{}".format(EXAMPLE_RUN))
1040
1064
 
1041
1065
  parser.add_argument(
1042
1066
  "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1254,7 +1278,7 @@ def main():
1254
1278
  latent_dist = args.z_dist
1255
1279
 
1256
1280
  ###########################################
1257
- sure = SURE(
1281
+ perturbflow = PerturbFlow(
1258
1282
  input_size=input_size,
1259
1283
  cell_factor_size=cell_factor_size,
1260
1284
  inverse_dispersion=args.inverse_dispersion,
@@ -1277,7 +1301,7 @@ def main():
1277
1301
  dtype=dtype,
1278
1302
  )
1279
1303
 
1280
- sure.fit(xs, us=us,
1304
+ perturbflow.fit(xs, us=us,
1281
1305
  num_epochs=args.num_epochs,
1282
1306
  learning_rate=args.learning_rate,
1283
1307
  batch_size=args.batch_size,
@@ -1289,9 +1313,9 @@ def main():
1289
1313
 
1290
1314
  if args.save_model is not None:
1291
1315
  if args.save_model.endswith('gz'):
1292
- SURE.save_model(sure, args.save_model, compression=True)
1316
+ PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
1293
1317
  else:
1294
- SURE.save_model(sure, args.save_model)
1318
+ PerturbFlow.save_model(perturbflow, args.save_model)
1295
1319
 
1296
1320
 
1297
1321