SURE-tools 2.0.10__tar.gz → 2.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {sure_tools-2.0.10 → sure_tools-2.1.1}/PKG-INFO +1 -1
  2. sure_tools-2.0.10/SURE/SURE.py → sure_tools-2.1.1/SURE/PerturbFlow.py +21 -30
  3. sure_tools-2.1.1/SURE/SURE.py +1236 -0
  4. sure_tools-2.1.1/SURE/__init__.py +11 -0
  5. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE_tools.egg-info/PKG-INFO +1 -1
  6. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE_tools.egg-info/SOURCES.txt +1 -0
  7. sure_tools-2.1.1/SURE_tools.egg-info/entry_points.txt +3 -0
  8. {sure_tools-2.0.10 → sure_tools-2.1.1}/setup.py +3 -2
  9. sure_tools-2.0.10/SURE/__init__.py +0 -10
  10. sure_tools-2.0.10/SURE_tools.egg-info/entry_points.txt +0 -2
  11. {sure_tools-2.0.10 → sure_tools-2.1.1}/LICENSE +0 -0
  12. {sure_tools-2.0.10 → sure_tools-2.1.1}/README.md +0 -0
  13. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/assembly/__init__.py +0 -0
  14. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/assembly/assembly.py +0 -0
  15. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/assembly/atlas.py +0 -0
  16. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/atac/__init__.py +0 -0
  17. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/atac/utils.py +0 -0
  18. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/codebook/__init__.py +0 -0
  19. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/codebook/codebook.py +0 -0
  20. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/flow/__init__.py +0 -0
  21. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/flow/plot_quiver.py +0 -0
  22. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/utils/__init__.py +0 -0
  23. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/utils/custom_mlp.py +0 -0
  24. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/utils/queue.py +0 -0
  25. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE/utils/utils.py +0 -0
  26. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE_tools.egg-info/requires.txt +0 -0
  28. {sure_tools-2.0.10 → sure_tools-2.1.1}/SURE_tools.egg-info/top_level.txt +0 -0
  29. {sure_tools-2.0.10 → sure_tools-2.1.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.0.10
3
+ Version: 2.1.1
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -54,7 +54,7 @@ def set_random_seed(seed):
54
54
  # Set seed for Pyro
55
55
  pyro.set_rng_seed(seed)
56
56
 
57
- class SURE(nn.Module):
57
+ class PerturbFlow(nn.Module):
58
58
  """SUccinct REpresentation of single-omics cells
59
59
 
60
60
  Parameters
@@ -97,7 +97,6 @@ class SURE(nn.Module):
97
97
  input_size: int,
98
98
  codebook_size: int = 200,
99
99
  cell_factor_size: int = 0,
100
- cell_factor_names: list = None,
101
100
  supervised_mode: bool = False,
102
101
  z_dim: int = 10,
103
102
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
@@ -135,7 +134,6 @@ class SURE(nn.Module):
135
134
  self.post_layer_fct = post_layer_fct
136
135
  self.post_act_fct = post_act_fct
137
136
  self.hidden_layer_activation = hidden_layer_activation
138
- self.cell_factor_names = cell_factor_names
139
137
 
140
138
  self.codebook_weights = None
141
139
 
@@ -330,7 +328,7 @@ class SURE(nn.Module):
330
328
  return xs
331
329
 
332
330
  def model1(self, xs):
333
- pyro.module('sure', self)
331
+ pyro.module('PerturbFlow', self)
334
332
 
335
333
  eps = torch.finfo(xs.dtype).eps
336
334
  batch_size = xs.size(0)
@@ -407,7 +405,7 @@ class SURE(nn.Module):
407
405
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
408
406
 
409
407
  def model2(self, xs, us=None):
410
- pyro.module('sure', self)
408
+ pyro.module('PerturbFlow', self)
411
409
 
412
410
  eps = torch.finfo(xs.dtype).eps
413
411
  batch_size = xs.size(0)
@@ -495,7 +493,7 @@ class SURE(nn.Module):
495
493
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
496
494
 
497
495
  def model3(self, xs, ys, embeds=None):
498
- pyro.module('sure', self)
496
+ pyro.module('PerturbFlow', self)
499
497
 
500
498
  eps = torch.finfo(xs.dtype).eps
501
499
  batch_size = xs.size(0)
@@ -587,7 +585,7 @@ class SURE(nn.Module):
587
585
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
588
586
 
589
587
  def model4(self, xs, us, ys, embeds=None):
590
- pyro.module('sure', self)
588
+ pyro.module('PerturbFlow', self)
591
589
 
592
590
  eps = torch.finfo(xs.dtype).eps
593
591
  batch_size = xs.size(0)
@@ -813,11 +811,8 @@ class SURE(nn.Module):
813
811
  A = np.concatenate(A)
814
812
  return A
815
813
 
816
- def _cell_move(self, xs, factor_idx, perturb):
817
- zns,_ = self.encoder_zn(xs)
818
- if type(factor_idx) == str:
819
- factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
820
-
814
+ def _cell_response(self, xs, factor_idx, perturb):
815
+ zns,_ = self.encoder_zn(xs)
821
816
  if perturb.ndim==2:
822
817
  ms = self.cell_factor_effect[factor_idx]([zns, perturb])
823
818
  else:
@@ -825,7 +820,7 @@ class SURE(nn.Module):
825
820
 
826
821
  return ms
827
822
 
828
- def get_cell_move(self,
823
+ def get_cell_response(self,
829
824
  xs,
830
825
  factor_idx,
831
826
  perturb,
@@ -843,27 +838,23 @@ class SURE(nn.Module):
843
838
  Z = []
844
839
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
845
840
  for X_batch, P_batch, _ in dataloader:
846
- zns = self._cell_move(X_batch, factor_idx, P_batch)
841
+ zns = self._cell_response(X_batch, factor_idx, P_batch)
847
842
  Z.append(tensor_to_numpy(zns))
848
843
  pbar.update(1)
849
844
 
850
845
  Z = np.concatenate(Z)
851
846
  return Z
852
847
 
853
- def get_metacell_move(self, factor_idx, perturb):
848
+ def get_metacell_response(self, factor_idx, perturb):
854
849
  zs = self._get_codebook()
855
- ps = convert_to_tensor(perturb, device=self.get_device())
856
-
857
- if type(factor_idx) == str:
858
- factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
859
-
850
+ ps = convert_to_tensor(perturb, device=self.get_device())
860
851
  ms = self.cell_factor_effect[factor_idx]([zs,ps])
861
852
  return tensor_to_numpy(ms)
862
853
 
863
- def _get_expression_responses(self, delta_zs):
854
+ def _get_expression_response(self, delta_zs):
864
855
  return self.decoder_concentrate(delta_zs)
865
856
 
866
- def get_expression_responses(self,
857
+ def get_expression_response(self,
867
858
  delta_zs,
868
859
  batch_size: int = 1024):
869
860
  """
@@ -911,7 +902,7 @@ class SURE(nn.Module):
911
902
  threshold: int = 0,
912
903
  use_jax: bool = False):
913
904
  """
914
- Train the SURE model.
905
+ Train the PerturbFlow model.
915
906
 
916
907
  Parameters
917
908
  ----------
@@ -937,7 +928,7 @@ class SURE(nn.Module):
937
928
  Parameter for optimization.
938
929
  use_jax
939
930
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
940
- the Python script or Jupyter notebook. It is OK if it is used when runing SURE in the shell command.
931
+ the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
941
932
  """
942
933
  xs = self.preprocess(xs, threshold=threshold)
943
934
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1055,12 +1046,12 @@ class SURE(nn.Module):
1055
1046
 
1056
1047
 
1057
1048
  EXAMPLE_RUN = (
1058
- "example run: SURE --help"
1049
+ "example run: PerturbFlow --help"
1059
1050
  )
1060
1051
 
1061
1052
  def parse_args():
1062
1053
  parser = argparse.ArgumentParser(
1063
- description="SURE\n{}".format(EXAMPLE_RUN))
1054
+ description="PerturbFlow\n{}".format(EXAMPLE_RUN))
1064
1055
 
1065
1056
  parser.add_argument(
1066
1057
  "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1278,7 +1269,7 @@ def main():
1278
1269
  latent_dist = args.z_dist
1279
1270
 
1280
1271
  ###########################################
1281
- sure = SURE(
1272
+ perturbflow = PerturbFlow(
1282
1273
  input_size=input_size,
1283
1274
  cell_factor_size=cell_factor_size,
1284
1275
  inverse_dispersion=args.inverse_dispersion,
@@ -1301,7 +1292,7 @@ def main():
1301
1292
  dtype=dtype,
1302
1293
  )
1303
1294
 
1304
- sure.fit(xs, us=us,
1295
+ perturbflow.fit(xs, us=us,
1305
1296
  num_epochs=args.num_epochs,
1306
1297
  learning_rate=args.learning_rate,
1307
1298
  batch_size=args.batch_size,
@@ -1313,9 +1304,9 @@ def main():
1313
1304
 
1314
1305
  if args.save_model is not None:
1315
1306
  if args.save_model.endswith('gz'):
1316
- SURE.save_model(sure, args.save_model, compression=True)
1307
+ PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
1317
1308
  else:
1318
- SURE.save_model(sure, args.save_model)
1309
+ PerturbFlow.save_model(perturbflow, args.save_model)
1319
1310
 
1320
1311
 
1321
1312